Unverified Commit 8a599895 authored by oahzxl, committed by GitHub

Re-organize file and align results (#99)

* Move template modification from nn to fastnn (#97)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* Align results and update unit tests (#98)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* update evoformer stack test

* update test

* update msa_att_row

* update msa_att_col

* update evoformer and evo-stack

* update evoformer

* update extramsa

* move model loading out of the loop

* finish template test

* update test

* Move template modification from nn to fastnn (#84)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* move model out of function

* only test inference

* remove cache in build

* update test inference

* restore changes

* restore build changes

* update inference and evoformer stack

* fix some bug

* update test

* update evoformer stack test

* update test

* update test

* fix test

* update test

* update test

* update input embedder

* update embedder

* reset core

* update test

* support template multimer in inject_nn
parent 7a69a181
-from .msa import MSAStack, ExtraMSAStack
+from .msa import MSACore, ExtraMSACore, ExtraMSABlock, ExtraMSAStack
 from .ops import OutProductMean, set_chunk_size
-from .triangle import PairStack
-from .evoformer import Evoformer
-from .blocks import EvoformerBlock, ExtraMSABlock, TemplatePairStackBlock
+from .triangle import PairCore
+from .evoformer import Evoformer, EvoformerStack
+from .template import TemplatePairBlock, TemplatePairStack

-__all__ = ['MSAStack', 'ExtraMSAStack', 'OutProductMean', 'PairStack', 'Evoformer',
-           'set_chunk_size', 'EvoformerBlock', 'ExtraMSABlock', 'TemplatePairStackBlock']
+__all__ = [
+    'MSACore', 'OutProductMean', 'PairCore', 'set_chunk_size',
+    'TemplatePairBlock', 'TemplatePairStack',
+    'ExtraMSACore', 'ExtraMSABlock', 'ExtraMSAStack',
+    'Evoformer', 'EvoformerStack',
+]
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple
from functools import partial
from fastfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from fastfold.model.fastnn.ops import Linear
from fastfold.utils.tensor_utils import one_hot
from fastfold.model.fastnn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
ri:
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb += tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
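# Minimal usage sketch (hedged): the sizes below (tf_dim=22, msa_dim=49, c_z=128, c_m=256,
# relpos_k=32) are assumed AlphaFold-style defaults, not values taken from this file, and the
# helper name is illustrative. It only exercises the shapes documented in the forward() docstring.
def _example_input_embedder():
    embedder = InputEmbedder(tf_dim=22, msa_dim=49, c_z=128, c_m=256, relpos_k=32)
    n_clust, n_res = 4, 16
    tf = torch.rand(n_res, 22)            # "target_feat"   [N_res, tf_dim]
    ri = torch.arange(n_res).float()      # "residue_index" [N_res]
    msa = torch.rand(n_clust, n_res, 49)  # "msa_feat"      [N_clust, N_res, msa_dim]
    msa_emb, pair_emb = embedder(tf, ri, msa)
    assert msa_emb.shape == (n_clust, n_res, 256)  # [*, N_clust, N_res, C_m]
    assert pair_emb.shape == (n_res, n_res, 128)   # [*, N_res, N_res, C_z]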
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
inplace=False
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
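        # For very small chunk sizes (1-4) the per-template pair activations are staged in a CPU
        # buffer and copied back one template at a time, trading host-device transfers for GPU
        # memory; otherwise the buffer stays on z's device. The hard-coded 64 is the template
        # pair channel width this buffer assumes.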
if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device='cpu')
else:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
tt = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
chunk=chunk_size,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype).to(z.device)
tt = self.template_pair_embedder(tt)
# single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
# [*, S_t, N, N, C_z]
if inplace:
tt = [tt]
t[i] = self.template_pair_stack.inplace(
tt,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)[0].to(t.device)
else:
t[i] = self.template_pair_stack(
tt,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
).to(t.device)
del tt, single_template_feats
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, N, N, C_z]
if inplace:
z = self.template_pointwise_att.inplace(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
)
else:
z = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
)
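        # The pointwise attention above is chunked with chunk_size * 256: it attends over the
        # flattened N_res x N_res query positions (one query per residue pair), so each chunk is
        # presumably far cheaper than an Evoformer-style chunk and a much larger chunk size still
        # fits in memory.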
ret = {}
ret["template_pair_embedding"] = z
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
return ret
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, initializer="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, initializer="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, initializer="relu")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
return x
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Dict
from fastfold.utils import all_atom_multimer
from fastfold.utils.feats import dgram_from_positions
from fastfold.model.fastnn.ops import Linear, LayerNorm
from fastfold.model.fastnn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils import geometry
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if self.use_chain_relative:
self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
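        # Worked example with assumed AlphaFold-Multimer defaults (max_relative_idx=32,
        # max_relative_chain=2) and use_chain_relative=True:
        #   no_bins = (2*32 + 2) + 1 + (2*2 + 2) = 66 + 1 + 6 = 73
        # i.e. clipped residue offsets plus an "other chain" bin, one same-entity flag, and
        # clipped chain offsets plus an "other entity" bin.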
def relpos(self, batch: Dict[str, torch.Tensor]):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = asym_id[..., None] == asym_id[..., None, :]
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if self.use_chain_relative:
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
)
rel_pos = torch.nn.functional.one_hot(
final_offset,
2 * self.max_relative_idx + 2,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = entity_id[..., None] == entity_id[..., None, :]
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
)
rel_chain = torch.nn.functional.one_hot(
final_rel_chain.long(),
2 * max_rel_chain + 2,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset,
2 * self.max_relative_idx + 1,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
return self.linear_relpos(rel_feat)
def forward(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super().__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super().__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_features = self.template_single_embedder(
template_features
)
template_features = torch.nn.functional.relu(
template_features
)
template_features = self.template_projector(
template_features,
)
out["template_single_embedding"] = (
template_features
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
inplace
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
if not inplace:
# [*, S_t, N, N, C_z]
template_embeds["template_pair_embedding"] = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
else:
template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
# [*, S_t, N, N, C_z]
template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)[0].to(z.device)
# [*, N, N, C_z]
template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
return template_embeds
from typing import Optional, Tuple
from functools import partial
import torch
import torch.nn as nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
-from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack
from fastfold.utils.checkpointing import checkpoint_blocks


class Evoformer(nn.Module):
-    def __init__(self, d_node=256, d_pair=128):
-        super(Evoformer, self).__init__()
-        self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
-        self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
-        self.pair_stack = PairStack(d_pair=d_pair)
-
-    def forward(self, node, pair, node_mask, pair_mask):
-        node = self.msa_stack(node, pair, node_mask)
-        pair = self.communication(node, node_mask, pair)
-        node, work = All_to_All_Async.apply(node, 1, 2)
-        pair = self.pair_stack(pair, pair_mask)
-        node = All_to_All_Async_Opp.apply(node, work, 1, 2)
-        return node, pair
    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool=False):
        super(Evoformer, self).__init__()
        self.first_block = first_block
        self.last_block = last_block

        self.msa = MSACore(c_m, c_z, p_drop=0.15)
        self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
        self.pair = PairCore(d_pair=c_z)
        self.is_multimer = is_multimer

    def forward(
        self,
        m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
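        # padding_size is always at least 1 (and at most dap_size), even when seq_length divides
        # dap_size evenly, so the [..., :-padding_size, ...] slices applied in the last block are
        # always well-defined.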
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m = self.msa(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
return m, z
def inplace(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
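        # In the inplace path m and z arrive as single-element lists wrapping the tensors, so this
        # block and the .inplace methods of the sub-modules can write results back into the same
        # storage; hence the m[0] / z[0] indexing below.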
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m[0] = m[0].unsqueeze(0)
z[0] = z[0].unsqueeze(0)
m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
m[0] = scatter(m[0], dim=1)
z[0] = scatter(z[0], dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask)
if self.last_block:
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
m[0] = gather(m[0], dim=0)
z[0] = gather(z[0], dim=0)
m[0] = m[0][:, :-padding_size, :]
z[0] = z[0][:-padding_size, :-padding_size, :]
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(
self,
c_m: int,
c_z: int,
c_s: int,
no_blocks: int,
blocks_per_ckpt: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = Evoformer(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
self.linear = Linear(c_m, c_s)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[..., 0, :, :])
return m, z, s
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b.inplace,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[0][..., 0, :, :])
return m, z, s
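# Minimal sketch of driving the stack (hedged): it assumes a FastFold/ColossalAI tensor-parallel
# context has already been initialized, and uses assumed AlphaFold-style sizes (c_m=256, c_z=128,
# c_s=384, no_blocks=48); the helper name is illustrative.
def _example_evoformer_stack(m, z, msa_mask, pair_mask, chunk_size=None):
    stack = EvoformerStack(c_m=256, c_z=128, c_s=384, no_blocks=48, blocks_per_ckpt=1)
    # m: [N_seq, N_res, 256], z: [N_res, N_res, 128]
    m, z, s = stack(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size)
    # s is the single representation read from the first MSA row: [N_res, 384]
    return m, z, s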
@@ -13,16 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fastfold.model.fastnn.kernel import LayerNorm
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc

-from fastfold.model.fastnn.ops import ChunkMSARowAttentionWithPairBias, ChunkTransition, SelfAttention, GlobalAttention, Transition, ChunkMSAColumnGlobalAttention
-from fastfold.model.fastnn.kernel import bias_dropout_add
-from fastfold.distributed import scatter, row_to_col
-from fastfold.distributed.comm_async import gather_async
+from fastfold.model.fastnn.kernel import LayerNorm, bias_dropout_add
+from fastfold.model.fastnn.ops import (ChunkMSARowAttentionWithPairBias, ChunkTransition,
+                                       SelfAttention, GlobalAttention, Transition,
+                                       ChunkMSAColumnGlobalAttention, OutProductMean)
+from fastfold.distributed.comm import gather, scatter, row_to_col, scatter
+from fastfold.distributed.comm_async import gather_async, All_to_All_Async, All_to_All_Async_Opp
+from fastfold.model.fastnn.triangle import PairCore


 class MSARowAttentionWithPairBias(nn.Module):
@@ -120,10 +125,10 @@ class MSAColumnGlobalAttention(nn.Module):
         return M_raw + M


-class MSAStack(nn.Module):
+class MSACore(nn.Module):

     def __init__(self, d_node, d_pair, p_drop=0.15):
-        super(MSAStack, self).__init__()
+        super(MSACore, self).__init__()
         self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
                                                                        d_pair=d_pair,
@@ -146,9 +151,9 @@ class MSAStack(nn.Module):
         return node


-class ExtraMSAStack(nn.Module):
+class ExtraMSACore(nn.Module):

     def __init__(self, d_node, d_pair, p_drop=0.15):
-        super(ExtraMSAStack, self).__init__()
+        super(ExtraMSACore, self).__init__()
         self.MSARowAttentionWithPairBias = ChunkMSARowAttentionWithPairBias(
             d_node=d_node, d_pair=d_pair, p_drop=p_drop, c=8
@@ -180,3 +185,280 @@ class ExtraMSAStack(nn.Module):
         node = self.MSATransition.inplace(node)
         return node
class ExtraMSABlock(nn.Module):
def __init__(
self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer=False
):
super(ExtraMSABlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = ExtraMSACore(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairCore(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_cnt = msa_mask.size(-2)
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(
m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z = torch.nn.functional.pad(
z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
m = scatter(m, dim=1) if not self.is_multimer else scatter(m, dim=2)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(
msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
pair_mask = torch.nn.functional.pad(
pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
)
if not self.is_multimer:
m = self.msa_stack(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa_stack(m, z_ori, msa_mask)
if self.last_block:
m = gather(m, dim=1) if not self.is_multimer else gather(m, dim=2)
z = gather(z, dim=1)
m = m[:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z = z[:, :-seq_len_padding_size, :-seq_len_padding_size, :]
m = m.squeeze(0)
z = z.squeeze(0)
return m, z
def inplace(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_cnt = msa_mask.size(-2)
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m[0] = m[0].unsqueeze(0)
z[0] = z[0].unsqueeze(0)
m[0] = torch.nn.functional.pad(
m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z[0] = torch.nn.functional.pad(
z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
z[0] = scatter(z[0], dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(
msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
pair_mask = torch.nn.functional.pad(
pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
)
if not self.is_multimer:
m = self.msa_stack.inplace(m, z, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z = self.pair_stack.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = [z[0].clone()]
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m = self.msa_stack.inplace(m, z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m = self.msa_stack.inplace(m, z, msa_mask)
z = self.pair_stack.inplace(z, pair_mask)
if self.last_block:
m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
z[0] = gather(z[0], dim=1)
m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
return m, z
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
no_blocks: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b.inplace(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
@@ -1166,7 +1166,7 @@ class GlobalAttention(nn.Module):
         q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
             torch.sum(mask_part, dim=-1)[..., None] + self.eps
         )
+        q = q * self.scaling
         q = self.to_q(q)
         q = q.view(q.shape[:-1] + (self.n_head, -1))
@@ -1189,3 +1189,102 @@ class GlobalAttention(nn.Module):
         m = torch.cat(output, dim=1)

         return m
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
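        # Inlined relative positional encoding (Algorithm 4): the residue-index offsets are
        # one-hot encoded against the 2*relpos_k+1 boundaries below; argmin over |d - bin| picks
        # the nearest boundary, which also clamps offsets beyond +/-relpos_k.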
ri = ri.type(tf_emb_i.dtype)
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
pair_emb = d[..., None] - reshaped_bins
pair_emb = torch.argmin(torch.abs(pair_emb), dim=-1)
pair_emb = nn.functional.one_hot(pair_emb, num_classes=len(boundaries)).float().type(ri.dtype)
pair_emb = self.linear_relpos(pair_emb)
pair_emb += tf_emb_i[..., None, :]
pair_emb += tf_emb_j[..., None, :, :]
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
@@ -5,7 +5,11 @@ import torch.nn as nn
 from fastfold.model.fastnn.kernel import LayerNorm
 from fastfold.distributed.comm import col_to_row, row_to_col, scatter
 from fastfold.model.fastnn.kernel import bias_dropout_add, bias_ele_dropout_residual
-from fastfold.model.fastnn.ops import Linear, SelfAttention, ChunkTransition, ChunkTriangleAttentionEndingNode, AsyncChunkTriangleMultiplicationOutgoing, AsyncChunkTriangleMultiplicationIncoming, ChunkTriangleAttentionStartingNode
+from fastfold.model.fastnn.ops import (Linear, SelfAttention, ChunkTransition,
+                                       ChunkTriangleAttentionStartingNode,
+                                       ChunkTriangleAttentionEndingNode,
+                                       AsyncChunkTriangleMultiplicationOutgoing,
+                                       AsyncChunkTriangleMultiplicationIncoming)
 from fastfold.distributed.comm_async import gather_async_opp, gather_async

@@ -209,10 +213,10 @@ class TriangleAttentionEndingNode(nn.Module):
                                          training=self.training)


-class PairStack(nn.Module):
+class PairCore(nn.Module):

     def __init__(self, d_pair, p_drop=0.25):
-        super(PairStack, self).__init__()
+        super(PairCore, self).__init__()

         self.d_pair = d_pair
         self.n_head = 4
...
@@ -36,9 +36,6 @@ from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
 from fastfold.model.nn.heads import AuxiliaryHeads
 import fastfold.common.residue_constants as residue_constants
 from fastfold.model.nn.structure_module import StructureModule
-from fastfold.model.loss import (
-    compute_plddt,
-)
 from fastfold.utils.tensor_utils import (
     dict_multimap,
     tensor_tree_map,
@@ -279,7 +276,8 @@ class AlphaFold(nn.Module):
                 # [*, N, N, C_z]
                 z = z + template_embeds["template_pair_embedding"]
             else:
-                template_embeds, z = self.template_embedder(
+                if self.globals.inplace:
+                    template_embeds = self.template_embedder(
                         template_feats,
                         z,
                         pair_mask.to(dtype=z.dtype),
@@ -287,7 +285,16 @@ class AlphaFold(nn.Module):
                         self.globals.chunk_size,
                         inplace=self.globals.inplace
                     )
+                    z = template_embeds["template_pair_embedding"]
+                else:
+                    template_embeds = self.template_embedder(
+                        template_feats,
+                        z,
+                        pair_mask.to(dtype=z.dtype),
+                        no_batch_dims,
+                        self.globals.chunk_size,
+                    )
+                    z = z + template_embeds["template_pair_embedding"]

         if(
             self.config.template.embed_angles or
             (self.globals.is_multimer and self.config.template.enabled)
...
@@ -25,7 +25,6 @@ from fastfold.utils.feats import (
 )
 from fastfold.model.nn.primitives import Linear, LayerNorm
 from fastfold.utils.tensor_utils import one_hot
-from fastfold.model.fastnn.ops import RecyclingEmbedder
 from fastfold.model.nn.template import (
     TemplatePairStack,
     TemplatePointwiseAttention,
@@ -123,8 +122,8 @@ class InputEmbedder(nn.Module):
         tf_emb_j = self.linear_tf_z_j(tf)

         # [*, N_res, N_res, c_z]
-        pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
-        pair_emb += tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
+        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
+        pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))

         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
@@ -138,6 +137,101 @@ class InputEmbedder(nn.Module):
         return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
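        # d is now a one-hot distogram: each pairwise squared C_beta distance falls into exactly
        # one of the no_bins buckets spanning [min_bin, max_bin] Angstroms, with the last bucket
        # open-ended (upper edge inf).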
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
return m_update, z_update
 class TemplateEmbedder(nn.Module):
     def __init__(self, config):
         super(TemplateEmbedder, self).__init__()
@@ -162,18 +256,11 @@ class TemplateEmbedder(nn.Module):
             pair_mask,
             templ_dim,
             chunk_size,
-            _mask_trans=True,
-            inplace=False
+            _mask_trans=True
         ):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
-        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-            t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device='cpu')
-        else:
-            t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
         for i in range(n_templ):
             idx = batch["template_aatype"].new_tensor(i)
             single_template_feats = tensor_tree_map(
@@ -193,57 +280,48 @@ class TemplateEmbedder(nn.Module):
                 single_template_embeds["angle"] = a

             # [*, S_t, N, N, C_t]
-            tt = build_template_pair_feat(
+            t = build_template_pair_feat(
                 single_template_feats,
                 use_unit_vector=self.config.use_unit_vector,
                 inf=self.config.inf,
-                chunk=chunk_size,
                 eps=self.config.eps,
                 **self.config.distogram,
-            ).to(z.dtype).to(z.device)
-            tt = self.template_pair_embedder(tt)
-            # single_template_embeds.update({"pair": t})
+            ).to(z.dtype)
+            t = self.template_pair_embedder(t)
+            single_template_embeds.update({"pair": t})

             template_embeds.append(single_template_embeds)

-            # [*, S_t, N, N, C_z]
-            if inplace:
-                tt = [tt]
-                t[i] = self.template_pair_stack.inplace(
-                    tt,
-                    pair_mask.unsqueeze(-3).to(dtype=z.dtype),
-                    chunk_size=chunk_size,
-                    _mask_trans=_mask_trans,
-                )[0].to(t.device)
-            else:
-                t[i] = self.template_pair_stack(
-                    tt,
-                    pair_mask.unsqueeze(-3).to(dtype=z.dtype),
-                    chunk_size=chunk_size,
-                    _mask_trans=_mask_trans,
-                ).to(t.device)
-            del tt, single_template_feats
-
         template_embeds = dict_multimap(
             partial(torch.cat, dim=templ_dim),
             template_embeds,
         )

+        # [*, S_t, N, N, C_z]
+        t = self.template_pair_stack(
+            template_embeds["pair"],
+            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )

         # [*, N, N, C_z]
-        z = self.template_pointwise_att(
+        t = self.template_pointwise_att(
             t,
             z,
             template_mask=batch["template_mask"].to(dtype=z.dtype),
-            chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
+            chunk_size=chunk_size,
         )
+        t = t * (torch.sum(batch["template_mask"]) > 0)

         ret = {}

         if self.config.embed_angles:
             ret["template_single_embedding"] = template_embeds["angle"]

-        return ret, z
+        ret.update({"template_pair_embedding": t})
+
+        return ret


 class TemplateAngleEmbedder(nn.Module):
...
@@ -254,18 +254,18 @@ class TemplateSingleEmbedderMultimer(nn.Module):
         template_mask = template_chi_mask[..., 0]

-        template_features = self.template_single_embedder(
+        template_activations = self.template_single_embedder(
             template_features
         )
-        template_features = torch.nn.functional.relu(
-            template_features
+        template_activations = torch.nn.functional.relu(
+            template_activations
         )
-        template_features = self.template_projector(
-            template_features,
+        template_activations = self.template_projector(
+            template_activations,
         )
         out["template_single_embedding"] = (
-            template_features
+            template_activations
         )
         out["template_mask"] = template_mask
@@ -296,7 +296,6 @@ class TemplateEmbedderMultimer(nn.Module):
             templ_dim,
             chunk_size,
             multichain_mask_2d,
-            inplace
         ):
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
@@ -308,6 +307,7 @@ class TemplateEmbedderMultimer(nn.Module):
             )

             single_template_embeds = {}
+            act = 0.

             template_positions, pseudo_beta_mask = (
                 single_template_feats["template_pseudo_beta"],
@@ -361,27 +361,17 @@ class TemplateEmbedderMultimer(nn.Module):
             template_embeds,
         )

-        if not inplace:
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )
-        else:
-            template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )[0].to(z.device)
+        # [*, S_t, N, N, C_z]
+        t = self.template_pair_stack(
+            template_embeds["template_pair_embedding"],
+            padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+            chunk_size=chunk_size,
+            _mask_trans=False,
+        )

         # [*, N, N, C_z]
-        template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
-        template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
-        template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
+        t = torch.sum(t, dim=-4) / n_templ
+        t = torch.nn.functional.relu(t)
+        t = self.linear_t(t)
+        template_embeds["template_pair_embedding"] = t

         return template_embeds
\ No newline at end of file
@@ -539,60 +539,6 @@ class EvoformerStack(nn.Module):
         return m, z, s
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b.inplace,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[0][..., 0, :, :])
return m, z, s
 class ExtraMSAStack(nn.Module):
     """
@@ -688,49 +634,3 @@ class ExtraMSAStack(nn.Module):
             torch.cuda.empty_cache()
         return z
\ No newline at end of file
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b.inplace(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
\ No newline at end of file
...@@ -122,26 +122,11 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -122,26 +122,11 @@ class TemplatePointwiseAttention(nn.Module):
# [*, N_res, N_res, 1, C_z] # [*, N_res, N_res, 1, C_z]
biases = [bias] biases = [bias]
if chunk_size is not None: if chunk_size is not None:
para_dim_t0 = t.shape[0] z = self._chunk(z, t, biases, chunk_size)
para_dim_t1 = t.shape[1]
chunk_size_t = chunk_size * 4
mask = torch.sum(template_mask.to(z.device)) > 0
for ti in range(0, para_dim_t0, chunk_size_t):
t0 = t[ti:ti + chunk_size_t, :, :, :]
t0 = t0.to(z.device)
para_dim_t_part = t0.shape[0]
for i in range(0, para_dim_t_part, chunk_size):
for j in range(0, para_dim_t1, chunk_size):
z[i:i + chunk_size, j:j + chunk_size, :, :] += self.mha(
q_x=z[i + ti:i + ti + chunk_size, j:j + chunk_size, :, :], kv_x=t0[i:i + chunk_size, j:j + chunk_size, :, :], biases=biases
) * mask
else: else:
t = self.mha(q_x=z, kv_x=t, biases=biases) z = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
t = t * (torch.sum(template_mask) > 0)
z = z + t
# [*, N_res, N_res, C_z]
z = z.squeeze(-2) z = z.squeeze(-2)
return z return z
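The hunk above collapses the hand-rolled double loop into a single `self._chunk(z, t, biases, chunk_size)` call. Below is a hedged sketch of the underlying idea, chunking the pointwise attention along the residue dimension so the full attention over all templates is never materialised at once; `attn_fn` stands in for `self.mha` and the shapes are illustrative assumptions.

```python
# Illustrative chunking sketch, not FastFold's _chunk implementation.
import torch

def chunked_pointwise_attention(z, t, attn_fn, chunk_size):
    # z: [N_res, N_res, 1, C] queries, t: [N_res, N_res, N_templ, C] keys/values.
    # Process slabs of the first residue dimension so only chunk_size rows of
    # the attention are alive at any time.
    out = torch.empty_like(z)
    for i in range(0, z.shape[0], chunk_size):
        sl = slice(i, i + chunk_size)
        out[sl] = attn_fn(q_x=z[sl], kv_x=t[sl])
    return out

z = torch.randn(16, 16, 1, 32)
t = torch.randn(16, 16, 4, 32)
out = chunked_pointwise_attention(z, t, lambda q_x, kv_x: q_x, chunk_size=8)
```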
...@@ -368,48 +353,7 @@ class TemplatePairStack(nn.Module): ...@@ -368,48 +353,7 @@ class TemplatePairStack(nn.Module):
args=(t,), args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
if chunk_size is None:
chunk_size = t.shape[0]
for i in range(0, t.shape[0], chunk_size):
t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
return t
def inplace( t = self.layer_norm(t)
self,
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
mask:
[*, N_templ, N_res, N_res] mask
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
expand_idx = list(mask.shape)
expand_idx[-3] = t[0].shape[-4]
mask = mask.expand(*expand_idx)
t, = checkpoint_blocks(
blocks=[
partial(
b.inplace,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
if chunk_size is None:
chunk_size = t[0].shape[0]
for i in range(0, t[0].shape[0], chunk_size):
t[0][i:i + chunk_size] = self.layer_norm(t[0][i:i + chunk_size].to(mask.device)).to(t[0].device)
return t return t
\ No newline at end of file
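The `TemplatePairStack` hunk above removes a chunked application of the final layer norm (falling back to the full template dimension when `chunk_size` is `None`) in favour of a single call. A short sketch of the removed pattern, which rewrites the `[N_templ, N_res, N_res, C_t]` tensor slab by slab to bound peak activation memory:

```python
# Sketch of the chunked layer-norm pattern; sizes are illustrative.
import torch
import torch.nn as nn

def chunked_layer_norm(t, layer_norm, chunk_size=None):
    # A chunk_size of None degenerates to a single full-tensor pass.
    if chunk_size is None:
        chunk_size = t.shape[0]
    for i in range(0, t.shape[0], chunk_size):
        t[i:i + chunk_size] = layer_norm(t[i:i + chunk_size])
    return t

t = torch.randn(8, 16, 16, 64)        # [N_templ, N_res, N_res, C_t]
with torch.no_grad():
    t = chunked_layer_norm(t, nn.LayerNorm(64), chunk_size=2)
```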
from .inject_fastnn import inject_fastnn
__all__ = ['inject_fastnn']
\ No newline at end of file
...@@ -13,9 +13,13 @@ ...@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch import torch
import torch.nn as nn
from fastfold.model.fastnn import EvoformerBlock, ExtraMSABlock, TemplatePairStackBlock from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
from fastfold.model.fastnn.embedders import TemplateEmbedder
from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
def copy_layernorm(model_fast, model_ori): def copy_layernorm(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight) model_fast.weight.copy_(model_ori.weight)
model_fast.bias.copy_(model_ori.bias) model_fast.bias.copy_(model_ori.bias)
...@@ -27,6 +31,14 @@ def copy_linear(model_fast, model_ori): ...@@ -27,6 +31,14 @@ def copy_linear(model_fast, model_ori):
model_fast.bias.copy_(model_ori.bias) model_fast.bias.copy_(model_ori.bias)
def copy_native_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
try:
model_fast.bias.copy_(model_ori.bias)
except (AttributeError, TypeError):  # tolerate layers declared without a bias term
pass
def copy_kv_linear(model_fast, ori_k, ori_v): def copy_kv_linear(model_fast, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0)) model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))
...@@ -77,32 +89,41 @@ def copy_triangle_att(model_fast, model_ori): ...@@ -77,32 +89,41 @@ def copy_triangle_att(model_fast, model_ori):
model_fast.out_bias.copy_(model_ori.mha.linear_o.bias) model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)
def copy_native_att(model_fast, model_ori):
copy_native_linear(model_fast.linear_q, model_ori.linear_q)
copy_native_linear(model_fast.linear_k, model_ori.linear_k)
copy_native_linear(model_fast.linear_v, model_ori.linear_v)
copy_native_linear(model_fast.linear_o, model_ori.linear_o)
if model_ori.gating:
copy_native_linear(model_fast.linear_g, model_ori.linear_g)
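`copy_native_linear`, `copy_native_att`, and the other `copy_*` helpers all follow the same pattern: parameters are transplanted from the original module into the FastNN counterpart with in-place `Tensor.copy_`, so the calls must run inside `torch.no_grad()` (as the `inject_*` functions below do). A minimal illustration with plain `nn.Linear` layers standing in for the real modules:

```python
# Minimal illustration of the copy_* pattern; real FastFold modules are
# replaced here by plain linear layers.
import torch
import torch.nn as nn

src = nn.Linear(8, 4)   # stands in for the original (OpenFold-style) layer
dst = nn.Linear(8, 4)   # stands in for the FastNN replacement
with torch.no_grad():
    dst.weight.copy_(src.weight)
    dst.bias.copy_(src.bias)
assert torch.equal(dst.weight, src.weight)
```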
def copy_evoformer_para(block_fast, block_ori): def copy_evoformer_para(block_fast, block_ori):
# msa_stack # msa_stack
# MSARowAttentionWithPairBias # MSARowAttentionWithPairBias
copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM, copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m) block_ori.msa_att_row.layer_norm_m)
copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ, copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z) block_ori.msa_att_row.layer_norm_z)
copy_attention(block_fast.msa_stack.MSARowAttentionWithPairBias.attention, copy_attention(block_fast.msa.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha) block_ori.msa_att_row.mha)
block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_( block_fast.msa.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight) block_ori.msa_att_row.linear_z.weight)
block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_( block_fast.msa.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias) block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention # MSAColumnAttention
copy_layernorm(block_fast.msa_stack.MSAColumnAttention.layernormM, copy_layernorm(block_fast.msa.MSAColumnAttention.layernormM,
block_ori.msa_att_col._msa_att.layer_norm_m) block_ori.msa_att_col._msa_att.layer_norm_m)
copy_attention(block_fast.msa_stack.MSAColumnAttention.attention, copy_attention(block_fast.msa.MSAColumnAttention.attention,
block_ori.msa_att_col._msa_att.mha) block_ori.msa_att_col._msa_att.mha)
# MSATransition # MSATransition
copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition) copy_transition(block_fast.msa.MSATransition, block_ori.core.msa_transition)
# communication # communication
copy_layernorm(block_fast.communication.layernormM, copy_layernorm(block_fast.communication.layernormM,
...@@ -113,16 +134,16 @@ def copy_evoformer_para(block_fast, block_ori): ...@@ -113,16 +134,16 @@ def copy_evoformer_para(block_fast, block_ori):
# pair_stack # pair_stack
# TriangleMultiplicationOutgoing # TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out) copy_triangle(block_fast.pair.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming # TriangleMultiplicationIncoming
copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in) copy_triangle(block_fast.pair.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode # TriangleAttentionStartingNode
copy_triangle_att(block_fast.pair_stack.TriangleAttentionStartingNode, copy_triangle_att(block_fast.pair.TriangleAttentionStartingNode,
block_ori.core.tri_att_start) block_ori.core.tri_att_start)
copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end) copy_triangle_att(block_fast.pair.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition) copy_transition(block_fast.pair.PairTransition, block_ori.core.pair_transition)
def copy_global_attention(model_fast, model_ori): def copy_global_attention(model_fast, model_ori):
...@@ -222,87 +243,175 @@ def copy_template_pair_stack_para(block_fast, block_ori): ...@@ -222,87 +243,175 @@ def copy_template_pair_stack_para(block_fast, block_ori):
copy_transition(block_fast.PairTransition, block_ori.pair_transition) copy_transition(block_fast.PairTransition, block_ori.pair_transition)
def inject_evoformer(model): def copy_template_pair_block_para(fast_module, target_module):
with torch.no_grad(): with torch.no_grad():
fastfold_blocks = nn.ModuleList() for ori_block, fast_block in zip(target_module.blocks, fast_module.blocks):
for block_id, ori_block in enumerate(model.evoformer.blocks): copy_template_pair_stack_para(fast_block, ori_block)
c_m = ori_block.msa_att_row.c_in if ori_block.training == False:
c_z = ori_block.msa_att_row.c_z fast_block.eval()
is_multimer = ori_block.is_multimer
fastfold_block = EvoformerBlock(c_m=c_m,
c_z=c_z, def copy_template_para(block_fast, block_ori):
first_block=(block_id == 0), # TemplateAngleEmbedder
last_block=(block_id == len(model.evoformer.blocks) - 1), copy_linear(block_fast.template_angle_embedder.linear_1,
is_multimer=is_multimer, block_ori.template_angle_embedder.linear_1)
) copy_linear(block_fast.template_angle_embedder.linear_2,
block_ori.template_angle_embedder.linear_2)
copy_evoformer_para(fastfold_block, ori_block)
# TemplatePairEmbedder
fastfold_blocks.append(fastfold_block) copy_linear(block_fast.template_pair_embedder.linear,
block_ori.template_pair_embedder.linear)
model.evoformer.blocks = fastfold_blocks
# TemplatePairStack
return model copy_template_pair_block_para(block_fast.template_pair_stack,
block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# TemplatePointwiseAttention
copy_native_att(block_fast.template_pointwise_att.mha,
block_ori.template_pointwise_att.mha)
def copy_template_multimer_para(block_fast, block_ori):
# TemplatePairEmbedderMultimer
copy_linear(block_fast.template_pair_embedder.dgram_linear,
block_ori.template_pair_embedder.dgram_linear)
copy_linear(block_fast.template_pair_embedder.aatype_linear_1,
block_ori.template_pair_embedder.aatype_linear_1)
copy_linear(block_fast.template_pair_embedder.aatype_linear_2,
block_ori.template_pair_embedder.aatype_linear_2)
copy_layernorm(block_fast.template_pair_embedder.query_embedding_layer_norm,
block_ori.template_pair_embedder.query_embedding_layer_norm)
copy_linear(block_fast.template_pair_embedder.query_embedding_linear,
block_ori.template_pair_embedder.query_embedding_linear)
copy_linear(block_fast.template_pair_embedder.pseudo_beta_mask_linear,
block_ori.template_pair_embedder.pseudo_beta_mask_linear)
copy_linear(block_fast.template_pair_embedder.x_linear,
block_ori.template_pair_embedder.x_linear)
copy_linear(block_fast.template_pair_embedder.y_linear,
block_ori.template_pair_embedder.y_linear)
copy_linear(block_fast.template_pair_embedder.z_linear,
block_ori.template_pair_embedder.z_linear)
copy_linear(block_fast.template_pair_embedder.backbone_mask_linear,
block_ori.template_pair_embedder.backbone_mask_linear)
# TemplateSingleEmbedderMultimer
copy_linear(block_fast.template_single_embedder.template_single_embedder,
block_ori.template_single_embedder.template_single_embedder)
copy_linear(block_fast.template_single_embedder.template_projector,
block_ori.template_single_embedder.template_projector)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack,
block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# linear_t
copy_linear(block_fast.linear_t, block_ori.linear_t)
def inject_extraMsaBlock(model): def inject_evoformer(model):
with torch.no_grad(): with torch.no_grad():
new_model_blocks = nn.ModuleList() target_module = model.evoformer
for block_id, ori_block in enumerate(model.extra_msa_stack.blocks): fast_module = EvoformerStack(
c_m = ori_block.msa_att_row.c_in c_m=target_module.blocks[0].msa_att_row.c_in,
c_z = ori_block.msa_att_row.c_z c_z=target_module.blocks[0].msa_att_row.c_z,
is_multimer = ori_block.is_multimer c_s=target_module.linear.out_features,
new_model_block = ExtraMSABlock( no_blocks=len(target_module.blocks),
c_m=c_m, blocks_per_ckpt=target_module.blocks_per_ckpt,
c_z=c_z, clear_cache_between_blocks=target_module.clear_cache_between_blocks,
first_block=(block_id == 0), is_multimer=target_module.blocks[0].is_multimer,
last_block=(block_id == len(model.extra_msa_stack.blocks) - 1),
is_multimer=is_multimer
) )
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_evoformer_para(fast_block, target_block)
if target_block.training == False:
fast_block.eval()
copy_linear(fast_module.linear, target_module.linear)
model.evoformer = fast_module
copy_extra_msa_para(new_model_block, ori_block)
if ori_block.training == False:
new_model_block.eval()
new_model_blocks.append(new_model_block)
model.extra_msa_stack.blocks = new_model_blocks
def inject_extramsa(model):
def inject_templatePairBlock(model):
with torch.no_grad(): with torch.no_grad():
target_module = model.template_embedder.template_pair_stack.blocks target_module = model.extra_msa_stack
fastfold_blocks = nn.ModuleList() fast_module = ExtraMSAStack(
for block_id, ori_block in enumerate(target_module): c_m=target_module.blocks[0].msa_att_row.c_in,
c_t = ori_block.c_t c_z=target_module.blocks[0].msa_att_row.c_z,
c_hidden_tri_att = ori_block.c_hidden_tri_att no_blocks=len(target_module.blocks),
c_hidden_tri_mul = ori_block.c_hidden_tri_mul clear_cache_between_blocks=target_module.clear_cache_between_blocks,
no_heads = ori_block.no_heads is_multimer=target_module.blocks[0].is_multimer,
pair_transition_n = ori_block.pair_transition_n
dropout_rate = ori_block.dropout_rate
inf = ori_block.inf
fastfold_block = TemplatePairStackBlock(
c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att,
c_hidden_tri_mul=c_hidden_tri_mul,
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
inf=inf,
first_block=(block_id == 0),
last_block=(block_id == len(target_module) - 1),
) )
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_extra_msa_para(fast_block, target_block)
if target_block.training == False:
fast_block.eval()
model.extra_msa_stack = fast_module
copy_template_pair_stack_para(fastfold_block, ori_block)
if ori_block.training == False:
fastfold_block.eval()
fastfold_blocks.append(fastfold_block)
model.template_embedder.template_pair_stack.blocks = fastfold_blocks def inject_template(model):
with torch.no_grad():
if model.evoformer.blocks[0].is_multimer:
target_module = model.template_embedder
fast_module = TemplateEmbedderMultimer(config=model.template_embedder.config)
copy_template_multimer_para(fast_module, target_module)
if target_module.training == False:
fast_module.eval()
model.template_embedder = fast_module
else:
target_module = model.template_embedder
fast_module = TemplateEmbedder(config=model.template_embedder.config)
copy_template_para(fast_module, target_module)
if target_module.training == False:
fast_module.eval()
model.template_embedder = fast_module
def inject_embedder(model):
if model.evoformer.blocks[0].is_multimer:
return
# recycle embedder
with torch.no_grad():
target_module = model.recycling_embedder
fast_module = RecyclingEmbedder(
c_m=target_module.c_m,
c_z=target_module.c_z,
min_bin=target_module.min_bin,
max_bin=target_module.max_bin,
no_bins=target_module.no_bins,
inf=target_module.inf
)
copy_native_linear(fast_module.linear, target_module.linear)
copy_layernorm(fast_module.layer_norm_m, target_module.layer_norm_m)
copy_layernorm(fast_module.layer_norm_z, target_module.layer_norm_z)
if target_module.training == False:
fast_module.eval()
model.recycling_embedder = fast_module
# input embedder
with torch.no_grad():
target_module = model.input_embedder
fast_module = InputEmbedder(
tf_dim=target_module.tf_dim,
msa_dim=target_module.msa_dim,
c_z=target_module.c_z,
c_m=target_module.c_m,
relpos_k=target_module.relpos_k,
)
copy_linear(fast_module.linear_tf_z_i, target_module.linear_tf_z_i)
copy_linear(fast_module.linear_tf_z_j, target_module.linear_tf_z_j)
copy_linear(fast_module.linear_tf_m, target_module.linear_tf_m)
copy_linear(fast_module.linear_msa_m, target_module.linear_msa_m)
copy_linear(fast_module.linear_relpos, target_module.linear_relpos)
if target_module.training == False:
fast_module.eval()
model.input_embedder = fast_module
def inject_fastnn(model): def inject_fastnn(model):
inject_evoformer(model) inject_evoformer(model)
inject_extraMsaBlock(model) inject_extramsa(model)
inject_templatePairBlock(model) inject_template(model)
inject_embedder(model)
return model return model
\ No newline at end of file
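For reference, this is roughly how `inject_fastnn` is exercised by the unit tests added below: build the stock AlphaFold model, load the JAX parameters, then inject the FastNN modules into a copy and switch it to eval mode.

```python
# Usage sketch mirroring the fixtures in the tests below; it assumes the
# parameter file resolved by get_param_path is available on disk.
import copy
import torch

from fastfold.config import model_config
from fastfold.model.hub import AlphaFold
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.test_utils import get_param_path

with torch.no_grad():
    config = model_config('model_1')
    model = AlphaFold(config)
    import_jax_weights_(model, get_param_path())
    fast_model = inject_fastnn(copy.deepcopy(model)).eval()
```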
import os
def get_param_path():
# develop
if os.path.exists('/data/scratch/alphafold/alphafold/params/params_model_1.npz'):
return '/data/scratch/alphafold/alphafold/params/params_model_1.npz'
# test
return '/data/scratch/fastfold/weight.npz'
def get_data_path():
# develop
if os.path.exists('/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'):
return '/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'
# test
return '/data/scratch/fastfold/mono_batch.pkl'
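A hypothetical refinement of these helpers, not part of this commit: an environment-variable override (the name `FASTFOLD_PARAM_PATH` is an assumption) would let the hard-coded develop/CI paths remain as fallbacks while making the tests runnable elsewhere.

```python
# Hypothetical variant; FASTFOLD_PARAM_PATH is an illustrative name only.
import os

def get_param_path_with_override():
    override = os.environ.get('FASTFOLD_PARAM_PATH')
    if override and os.path.exists(override):
        return override
    # develop
    if os.path.exists('/data/scratch/alphafold/alphafold/params/params_model_1.npz'):
        return '/data/scratch/alphafold/alphafold/params/params_model_1.npz'
    # test
    return '/data/scratch/fastfold/weight.npz'
```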
...@@ -38,7 +38,7 @@ from fastfold.data import data_pipeline, feature_pipeline, templates ...@@ -38,7 +38,7 @@ from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.data.tools import hhsearch, hmmsearch from fastfold.data.tools import hhsearch, hmmsearch
from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow
from fastfold.utils import inject_fastnn from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.data.parsers import parse_fasta from fastfold.data.parsers import parse_fasta
from fastfold.utils.import_weights import import_jax_weights_ from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map from fastfold.utils.tensor_utils import tensor_tree_map
......
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_module_and_output():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.evoformer
fast_module_1 = fast_module.blocks[0].eval().cuda()
fast_module_2 = fast_module.blocks[-1].eval().cuda()
target_module = target_module.evoformer
target_module_1 = target_module.blocks[0].eval().cuda()
target_module_2 = target_module.blocks[-1].eval().cuda()
msa_len = 80
seq_len = 80
m = torch.randn((msa_len, seq_len, 256))
m_mask = torch.ones((msa_len, seq_len))
z = torch.randn((seq_len, seq_len, 128))
z_mask = torch.ones((seq_len, seq_len))
data = [m, z, m_mask, z_mask]
inputs = [copy.deepcopy(i).cuda() for i in data]
m_out, z_out = target_module_1(*inputs)
m_out, z_out = target_module_2(m_out, z_out, inputs[2], inputs[3])
return fast_module_1, fast_module_2, m_out, z_out, data
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_module_and_output):
run_func = partial(_test_evoformer, world_size=world_size, chunk_size=chunk_size, inplace=inplace, get_module_and_output=get_module_and_output)
mp.spawn(run_func, nprocs=world_size)
def _test_evoformer(rank, world_size, chunk_size, inplace, get_module_and_output):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
fast_module_1, fast_module_2, m_out, z_out, data = get_module_and_output
fast_module_1 = copy.deepcopy(fast_module_1).eval().cuda()
fast_module_2 = copy.deepcopy(fast_module_2).eval().cuda()
inputs = [copy.deepcopy(i).cuda() for i in data]
set_chunk_size(chunk_size)
with torch.no_grad():
if not inplace:
m_fast, z_fast = fast_module_1(*inputs)
m_fast, z_fast = fast_module_2(m_fast, z_fast, inputs[2], inputs[3])
else:
m_fast, z_fast = fast_module_1.inplace([inputs[0]], [inputs[1]], inputs[2], inputs[3])
m_fast, z_fast = fast_module_2.inplace(m_fast, z_fast, inputs[2], inputs[3])
m_fast = m_fast[0]
z_fast = z_fast[0]
error = torch.mean(torch.abs(m_out.cuda() - m_fast))
assert error < 5e-4, f"Test m failed at chunk size: {chunk_size}, inplace: {inplace}. The mean absolute difference is {error}"
error = torch.mean(torch.abs(z_out.cuda() - z_fast))
assert error < 5e-4, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The mean absolute difference is {error}"
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_module_and_output():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
model = AlphaFold(config)
import_jax_weights_(model, get_param_path())
fast_model = copy.deepcopy(model)
fast_model = inject_fastnn(fast_model)
fast_model = fast_model.evoformer
fast_model.eval().cuda()
model = model.evoformer
model.eval().cuda()
msa_len = 50
seq_len = 52
m = torch.randn((msa_len, seq_len, 256))
m_mask = torch.ones((msa_len, seq_len)).to(dtype=m.dtype)
z = torch.randn((seq_len, seq_len, 128))
z_mask = torch.ones((seq_len, seq_len)).to(dtype=z.dtype)
data = [m, z, m_mask, z_mask]
inputs = [copy.deepcopy(i).cuda() for i in data]
out = model(
*inputs, chunk_size=None, _mask_trans=config.model._mask_trans)
return fast_model, config, out, data
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 1])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_module_and_output):
run_func = partial(_test_evoformer_stack, world_size=world_size, chunk_size=chunk_size,
inplace=inplace, get_module_and_output=get_module_and_output)
mp.spawn(run_func, nprocs=world_size)
def _test_evoformer_stack(rank, world_size, chunk_size, inplace, get_module_and_output):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
fast_module, config, out, data = get_module_and_output
inputs = [copy.deepcopy(i).cuda() for i in data]
fast_module = copy.deepcopy(fast_module).eval().cuda()
with torch.no_grad():
set_chunk_size(chunk_size)
if not inplace:
m_fast, z_fast, s_fast = fast_module(
*inputs, chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
else:
m_fast, z_fast, s_fast = fast_module.inplace(
[inputs[0]], [inputs[1]], inputs[2], inputs[3], chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
m_fast = m_fast[0]
z_fast = z_fast[0]
error = torch.mean(torch.abs(out[0].cuda() - m_fast))
assert error < 2e-3, f"Test m failed at chunk size: {chunk_size}, inplace: {inplace}. The mean absolute difference is {error}"
error = torch.mean(torch.abs(out[1].cuda() - z_fast))
assert error < 2e-3, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The mean absolute difference is {error}"
error = torch.mean(torch.abs(out[2].cuda() - s_fast))
assert error < 2e-3, f"Test s failed at chunk size: {chunk_size}, inplace: {inplace}. The mean absolute difference is {error}"
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.extra_msa_stack
fast_module = fast_module.cuda().eval()
extra_msa_len = 300
seq_len = 64
m = torch.randn((extra_msa_len, seq_len, 64)).cuda()
m_mask = torch.ones((extra_msa_len, seq_len)).cuda().to(dtype=m.dtype)
m_mask[64:, :] = 0.
z = torch.randn((seq_len, seq_len, 128)).cuda()
z_mask = torch.ones((seq_len, seq_len)).cuda().to(dtype=z.dtype)
data = [m, z, m_mask, z_mask]
inputs = [copy.deepcopy(i).cuda() for i in data]
target_module = target_module.extra_msa_stack
target_module = target_module.eval().cuda()
z_out = target_module(
inputs[0], inputs[1], msa_mask=inputs[2], pair_mask=inputs[3], chunk_size=None, _mask_trans=config.model._mask_trans)
return z_out, config, fast_module, data
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(_test_extramsa_stack, world_size=world_size, chunk_size=chunk_size, inplace=inplace,
get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_extramsa_stack(rank, world_size, chunk_size, inplace, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
z_out, config, fast_module, data = get_openfold_module_and_data
inputs = [copy.deepcopy(i).cuda() for i in data]
fast_module = copy.deepcopy(fast_module).eval().cuda()
with torch.no_grad():
set_chunk_size(chunk_size)
if not inplace:
z_fast = fast_module(
inputs[0], inputs[1], msa_mask=inputs[2], pair_mask=inputs[3], chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
else:
z_fast = fast_module.inplace(
[inputs[0]], [inputs[1]], msa_mask=inputs[2], pair_mask=inputs[3], chunk_size=chunk_size, _mask_trans=config.model._mask_trans)
z_fast = z_fast[0]
error = torch.mean(torch.abs(z_out.cuda() - z_fast))
assert error < 1e-3, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The mean absolute difference is {error}"