Commit fdcb72e8 authored by Christina Floristean's avatar Christina Floristean
Browse files

Bug fixes for multimer inference and monomer training

parent 51556d52
...@@ -155,14 +155,17 @@ def model_config( ...@@ -155,14 +155,17 @@ def model_config(
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif "multimer" in name: elif "multimer" in name:
c.globals.is_multimer = True c.globals.is_multimer = True
c.globals.bfloat16 = True c.globals.bfloat16 = False
c.globals.bfloat16_output = False c.globals.bfloat16_output = False
c.loss.masked_msa.num_classes = 22 c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20 c.data.common.max_recycling_iters = 20
for k,v in multimer_model_config_update.items(): for k,v in multimer_model_config_update['model'].items():
c.model[k] = v c.model[k] = v
for k, v in multimer_model_config_update['loss'].items():
c.loss[k] = v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name): if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252 #c.model.input_embedder.num_msa = 252
...@@ -590,6 +593,12 @@ config = mlc.ConfigDict( ...@@ -590,6 +593,12 @@ config = mlc.ConfigDict(
"c_out": 37, "c_out": 37,
}, },
}, },
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1.
}, },
"relax": { "relax": {
"max_iterations": 0, # no max "max_iterations": 0, # no max
...@@ -670,157 +679,154 @@ config = mlc.ConfigDict( ...@@ -670,157 +679,154 @@ config = mlc.ConfigDict(
"eps": eps, "eps": eps,
}, },
"ema": {"decay": 0.999}, "ema": {"decay": 0.999},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1
} }
) )
multimer_model_config_update = { multimer_model_config_update = {
"input_embedder": { "model": {
"tf_dim": 21, "input_embedder": {
"msa_dim": 49, "tf_dim": 21,
#"num_msa": 508, "msa_dim": 49,
"c_z": c_z, #"num_msa": 508,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z, "c_z": c_z,
"c_out": 64,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m, "c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
}, },
"template_pair_stack": { "template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z,
"c_out": 64,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
"c_t": c_t, "c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement "c_z": c_z,
# as 64. In the code, it's 16. "inf": 1e5, # 1e9,
"c_hidden_tri_att": 16, "eps": eps, # 1e-6,
"c_hidden_tri_mul": 64, "enabled": templates_enabled,
"no_blocks": 2, "embed_angles": embed_template_torsion_angles,
"no_heads": 4, "use_unit_vector": True
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
}, },
"c_t": c_t, "extra_msa": {
"c_z": c_z, "extra_msa_embedder": {
"inf": 1e5, # 1e9, "c_in": 25,
"eps": eps, # 1e-6, "c_out": c_e,
"enabled": templates_enabled, #"num_extra_msa": 2048
"embed_angles": embed_template_torsion_angles, },
"use_unit_vector": True "extra_msa_stack": {
}, "c_m": c_e,
"extra_msa": { "c_z": c_z,
"extra_msa_embedder": { "c_hidden_msa_att": 8,
"c_in": 25, "c_hidden_opm": 32,
"c_out": c_e, "c_hidden_mul": 128,
#"num_extra_msa": 2048 "c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
}, },
"extra_msa_stack": { "evoformer_stack": {
"c_m": c_e, "c_m": c_m,
"c_z": c_z, "c_z": c_z,
"c_hidden_msa_att": 8, "c_hidden_msa_att": 32,
"c_hidden_opm": 32, "c_hidden_opm": 32,
"c_hidden_mul": 128, "c_hidden_mul": 128,
"c_hidden_pair_att": 32, "c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8, "no_heads_msa": 8,
"no_heads_pair": 4, "no_heads_pair": 4,
"no_blocks": 4, "no_blocks": 48,
"transition_n": 4, "transition_n": 4,
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"opm_first": True, "opm_first": True,
"fuse_projection_weights": True, "fuse_projection_weights": True,
"clear_cache_between_blocks": True, "blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
}, },
"tm": { "structure_module": {
"c_s": c_s,
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "c_ipa": 16,
"ptm_weight": 0.2, "c_resnet": 128,
"iptm_weight": 0.8, "no_heads_ipa": 12,
"enabled": True, "no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
}, },
"masked_msa": { "heads": {
"c_m": c_m, "lddt": {
"c_out": 22, "no_bins": 50,
}, "c_in": c_s,
"experimentally_resolved": { "c_hidden": 128,
"c_s": c_s, },
"c_out": 37, "distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True,
},
"masked_msa": {
"c_m": c_m,
"c_out": 22,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
}, },
"recycle_early_stop_tolerance": 0.5
}, },
"loss": { "loss": {
"distogram": { "distogram": {
...@@ -897,6 +903,5 @@ multimer_model_config_update = { ...@@ -897,6 +903,5 @@ multimer_model_config_update = {
"enabled": True, "enabled": True,
}, },
"eps": eps, "eps": eps,
}, }
"recycle_early_stop_tolerance": 0.5
} }
...@@ -151,7 +151,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -151,7 +151,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
} }
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir, mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date, max_template_date=max_template_date,
max_hits=max_template_hits, max_hits=max_template_hits,
......
...@@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union ...@@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np import numpy as np
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
...@@ -34,22 +34,10 @@ FeatureDict = MutableMapping[str, np.ndarray] ...@@ -34,22 +34,10 @@ FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def empty_template_feats(n_res) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
def make_template_features( def make_template_features(
input_sequence: str, input_sequence: str,
hits: Sequence[Any], hits: Sequence[Any],
template_featurizer: Any, template_featurizer: Any,
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None): if(len(hits_cat) == 0 or template_featurizer is None):
...@@ -61,11 +49,6 @@ def make_template_features( ...@@ -61,11 +49,6 @@ def make_template_features(
) )
template_features = templates_result.features template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
return template_features return template_features
...@@ -453,7 +436,8 @@ class AlignmentRunner: ...@@ -453,7 +436,8 @@ class AlignmentRunner:
if(uniprot_database_path is not None): if(uniprot_database_path is not None):
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer( self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path, binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path database_path=uniprot_database_path,
n_cpu=no_cpus
) )
if(template_searcher is not None and if(template_searcher is not None and
...@@ -800,37 +784,6 @@ class DataPipeline: ...@@ -800,37 +784,6 @@ class DataPipeline:
return all_hits return all_hits
def _parse_template_hits(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
def _get_msas(self, def _get_msas(self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
...@@ -935,15 +888,15 @@ class DataPipeline: ...@@ -935,15 +888,15 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id) mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence,
alignment_index) alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
self.template_featurizer, self.template_featurizer
query_release_date=to_date(mmcif.header["release_date"])
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
...@@ -984,8 +937,9 @@ class DataPipeline: ...@@ -984,8 +937,9 @@ class DataPipeline:
is_distillation=is_distillation is_distillation=is_distillation
) )
hits = self._parse_template_hits( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence,
alignment_index alignment_index
) )
...@@ -1016,8 +970,9 @@ class DataPipeline: ...@@ -1016,8 +970,9 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper() description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description) core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence,
alignment_index alignment_index
) )
...@@ -1107,7 +1062,7 @@ class DataPipeline: ...@@ -1107,7 +1062,7 @@ class DataPipeline:
alignment_dir = os.path.join( alignment_dir = os.path.join(
super_alignment_dir, desc super_alignment_dir, desc
) )
hits = self._parse_template_hits(alignment_dir, alignment_index=None) hits = self._parse_template_hit_files(alignment_dir, seq, alignment_index=None)
template_features = make_template_features( template_features = make_template_features(
seq, seq,
hits, hits,
......
...@@ -89,18 +89,17 @@ def make_all_atom_aatype(protein): ...@@ -89,18 +89,17 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein): def fix_templates_aatype(protein):
# Map one-hot to indices # Map one-hot to indices
num_templates = protein["template_aatype"].shape[0] num_templates = protein["template_aatype"].shape[0]
if(num_templates > 0): protein["template_aatype"] = torch.argmax(
protein["template_aatype"] = torch.argmax( protein["template_aatype"], dim=-1
protein["template_aatype"], dim=-1 )
) # Map hhsearch-aatype to our aatype.
# Map hhsearch-aatype to our aatype. new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE new_order = torch.tensor(
new_order = torch.tensor( new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
new_order_list, dtype=torch.int64, device=protein["template_aatype"].device, ).expand(num_templates, -1)
).expand(num_templates, -1) protein["template_aatype"] = torch.gather(
protein["template_aatype"] = torch.gather( new_order, 1, index=protein["template_aatype"]
new_order, 1, index=protein["template_aatype"] )
)
return protein return protein
......
...@@ -2,7 +2,9 @@ from typing import Sequence ...@@ -2,7 +2,9 @@ from typing import Sequence
import torch import torch
from openfold.config import NUM_RES
from openfold.data.data_transforms import curry1 from openfold.data.data_transforms import curry1
from openfold.np import residue_constants as rc
from openfold.utils.tensor_utils import masked_mean from openfold.utils.tensor_utils import masked_mean
...@@ -301,3 +303,177 @@ def make_msa_profile(batch): ...@@ -301,3 +303,177 @@ def make_msa_profile(batch):
) )
return batch return batch
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
    """Find residues that contact a *different* chain.

    Args:
        positions: [*, N_res, N_atom, 3] atom coordinates.
        atom_mask: [*, N_res, N_atom] mask, 1.0 where the atom exists.
        asym_id: [*, N_res] per-residue chain identifier.
        interface_threshold: distance cutoff below which a residue pair
            from different chains counts as an interface contact.

    Returns:
        1-D tensor of indices of residues with at least one atom closer
        than ``interface_threshold`` to an atom of another chain.
    """
    # [*, N_res, N_res, N_atom]: distance between same-index atoms of
    # every residue pair.
    coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
    pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))

    diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
    pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
    # torch.where requires a boolean condition; the float product must be
    # cast before use.
    mask = (diff_chain_mask[..., None] * pair_mask).bool()

    # Tensor.min(dim=...) returns (values, indices); keep only the values,
    # otherwise the `<` comparison below fails on the namedtuple.
    min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)

    valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
    interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]

    return interface_residues_idxs
def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
    """Select up to ``crop_size`` residues spatially closest to a randomly
    chosen interface residue.

    Falls back to a contiguous crop when the complex has no inter-chain
    contacts within ``interface_threshold``.

    Args:
        protein: feature dict with "all_atom_positions", "all_atom_mask"
            and "asym_id" tensors.
        crop_size: maximum number of residues to keep.
        interface_threshold: contact cutoff passed to
            get_interface_residues.
        generator: torch.Generator used for all random draws.

    Returns:
        Sorted 1-D tensor of residue indices to keep.
    """
    positions = protein["all_atom_positions"]
    atom_mask = protein["all_atom_mask"]
    asym_id = protein["asym_id"]

    interface_residues = get_interface_residues(positions=positions,
                                                atom_mask=atom_mask,
                                                asym_id=asym_id,
                                                interface_threshold=interface_threshold)

    # numel(), not torch.any(): a lone interface residue at index 0 is a
    # valid result but is falsy under torch.any.
    if interface_residues.numel() == 0:
        return get_contiguous_crop_idx(protein, crop_size, generator)

    target_res = interface_residues[int(torch.randint(0, interface_residues.shape[-1], (1,),
                                                      device=positions.device, generator=generator)[0])]

    ca_idx = rc.atom_order["CA"]
    ca_positions = positions[..., ca_idx, :]
    ca_mask = atom_mask[..., ca_idx].bool()

    coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :]
    ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))

    # [N_res] CA-CA distances from the target residue; small index-based
    # offsets break exact-distance ties deterministically.
    to_target_distances = ca_pairwise_dists[target_res]
    break_tie = (
        torch.arange(
            0, to_target_distances.shape[-1], device=positions.device
        ).float()
        * 1e-3
    )
    # ca_mask must stay 1-D here: the previous ca_mask[..., None]
    # broadcast the [N_res] vector up to [N_res, N_res], producing a 2-D
    # (invalid) crop index tensor.
    to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie

    ret = torch.argsort(to_target_distances)[:crop_size]
    return ret.sort().values
def randint(lower, upper, generator, device):
    """Draw one integer uniformly from [lower, upper], inclusive of both
    endpoints (torch.randint's `high` is exclusive, hence the +1)."""
    sample = torch.randint(
        low=lower,
        high=upper + 1,
        size=(1,),
        device=device,
        generator=generator,
    )
    return int(sample[0])
def get_contiguous_crop_idx(protein, crop_size, generator):
    """Crop the complex to ``crop_size`` residues using one contiguous
    segment per chain (AlphaFold-Multimer's contiguous crop).

    Chains are visited in random order; each contributes a random-length
    contiguous slice, constrained so the slices sum exactly to
    ``crop_size`` when the complex is longer than that.

    Args:
        protein: feature dict with "aatype" and "asym_id" tensors.
        crop_size: number of residues to keep.
        generator: torch.Generator used for all random draws.

    Returns:
        1-D int64 tensor of residue indices to keep.
    """
    num_res = protein["aatype"].shape[0]
    if num_res <= crop_size:
        return torch.arange(num_res)

    device = protein["asym_id"].device

    def _randint(lower, upper):
        # One draw, inclusive of both endpoints.
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=device,
            generator=generator,
        )[0])

    # NOTE(review): unique() returns counts in order of appearance only
    # when asym_id is sorted; assumes chains are contiguous and ordered,
    # as produced upstream — confirm against the feature pipeline.
    _, chain_lens = protein["asym_id"].unique(return_counts=True)
    shuffle_idx = torch.randperm(
        chain_lens.shape[-1], device=chain_lens.device, generator=generator
    )

    num_remaining = int(chain_lens.sum())
    num_budget = crop_size
    crop_idxs = []
    asym_offset = torch.tensor(0, dtype=torch.int64)
    for idx in shuffle_idx:
        this_len = int(chain_lens[idx])
        num_remaining -= this_len

        # Most residues this chain may contribute.
        crop_size_max = min(num_budget, this_len)
        # Fewest it must contribute so the remaining chains can still
        # fill the budget.
        crop_size_min = min(this_len, max(0, num_budget - num_remaining))

        # _randint is inclusive of its upper bound; the previous
        # `crop_size_max + 1` could overdraw the budget by one residue.
        chain_crop_size = _randint(crop_size_min, crop_size_max)
        # Likewise, `this_len - chain_crop_size` is the last valid start;
        # the previous `+ 1` let the slice run past the chain boundary.
        chain_start = _randint(0, this_len - chain_crop_size)

        crop_idxs.append(
            torch.arange(
                asym_offset + chain_start,
                asym_offset + chain_start + chain_crop_size,
            )
        )

        asym_offset += this_len
        num_budget -= chain_crop_size

    return torch.concat(crop_idxs)
@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    spatial_crop_prob,
    interface_threshold,
    subsample_templates=False,
    seed=None,
):
    """Crop randomly to ``crop_size``, or keep as is if shorter than that.

    With probability ``spatial_crop_prob`` a spatial (interface-centred)
    crop is used; otherwise a per-chain contiguous crop. Template features
    are optionally permuted and subsampled, then truncated to
    ``max_templates``.

    Args:
        protein: feature dict; modified in place and returned.
        crop_size: residue crop size.
        max_templates: maximum number of templates to keep.
        shape_schema: per-feature dimension schema (NUM_RES markers).
        spatial_crop_prob: probability of using the spatial crop.
        interface_threshold: contact cutoff for the spatial crop.
        subsample_templates: whether to randomly subsample templates.
        seed: optional seed so every ensemble member crops identically.
    """
    # We want each ensemble to be cropped the same way
    g = torch.Generator(device=protein["seq_length"].device)
    if seed is not None:
        g.manual_seed(seed)

    use_spatial_crop = torch.rand((1,),
                                  device=protein["seq_length"].device,
                                  generator=g) < spatial_crop_prob

    if use_spatial_crop:
        crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
    else:
        crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    # No need to subsample templates if there aren't any
    subsample_templates = subsample_templates and num_templates

    if subsample_templates:
        # randint is inclusive of its upper bound; the previous
        # `num_templates + 1` could place the start past the last
        # template, making the template crop size negative.
        templates_crop_start = randint(lower=0,
                                       upper=num_templates,
                                       generator=g,
                                       device=protein["seq_length"].device)
        templates_select_indices = torch.randperm(
            num_templates, device=protein["seq_length"].device, generator=g
        )
    else:
        templates_crop_start = 0

    num_res_crop_size = min(int(protein["seq_length"]), crop_size)
    num_templates_crop_size = min(
        num_templates - templates_crop_start, max_templates
    )

    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith("template") and subsample_templates:
            v = v[templates_select_indices]

        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            if i == 0 and k.startswith("template"):
                # Use dedicated names rather than clobbering `crop_size`
                # (the residue crop size) as the previous version did.
                v = v[slice(
                    templates_crop_start,
                    templates_crop_start + num_templates_crop_size,
                )]
            elif is_num_res:
                v = torch.index_select(v, i, crop_idxs)

        protein[k] = v

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein
...@@ -104,7 +104,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -104,7 +104,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
# the masked locations and secret corrupted locations. # the masked locations and secret corrupted locations.
transforms.append( transforms.append(
data_transforms.make_masked_msa( data_transforms.make_masked_msa(
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction,
seed=(msa_seed + 1) if msa_seed else None,
) )
) )
......
...@@ -89,6 +89,24 @@ TEMPLATE_FEATURES = { ...@@ -89,6 +89,24 @@ TEMPLATE_FEATURES = {
} }
def empty_template_feats(n_res):
    """Return a template feature dict with zero templates for an
    ``n_res``-residue query.

    Array shapes mirror the real template features with a leading
    template dimension of size 0, except the domain-name/sequence
    fields, which hold a single empty entry.
    """
    return {
        "template_aatype": np.zeros(
            (0, n_res, len(residue_constants.restypes_with_x_and_gap)),
            np.float32
        ),
        "template_all_atom_mask": np.zeros(
            (0, n_res, residue_constants.atom_type_num), np.float32
        ),
        "template_all_atom_positions": np.zeros(
            (0, n_res, residue_constants.atom_type_num, 3), np.float32
        ),
        # `np.object` was removed in NumPy 1.24 (deprecated since 1.20);
        # the builtin `object` is the supported spelling of this dtype.
        "template_domain_names": np.array([''.encode()], dtype=object),
        "template_sequence": np.array([''.encode()], dtype=object),
        "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
    }
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]: def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
"""Returns PDB id and chain id for an HHSearch Hit.""" """Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown. # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
...@@ -1163,21 +1181,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer): ...@@ -1163,21 +1181,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
else: else:
num_res = len(query_sequence) num_res = len(query_sequence)
# Construct a default template with all zeros. # Construct a default template with all zeros.
template_features = { template_features = empty_template_feats(num_res)
"template_aatype": np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_masks": np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_sum_probs": np.array([0], dtype=np.float32),
}
return TemplateSearchResult( return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings features=template_features, errors=errors, warnings=warnings
...@@ -1276,21 +1280,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer): ...@@ -1276,21 +1280,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
else: else:
num_res = len(query_sequence) num_res = len(query_sequence)
# Construct a default template with all zeros. # Construct a default template with all zeros.
template_features = { template_features = empty_template_feats(num_res)
"template_aatype": np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_masks": np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_sum_probs": np.array([0], dtype=np.float32),
}
return TemplateSearchResult( return TemplateSearchResult(
features=template_features, features=template_features,
......
...@@ -242,7 +242,7 @@ class InputEmbedderMultimer(nn.Module): ...@@ -242,7 +242,7 @@ class InputEmbedderMultimer(nn.Module):
entity_id = batch["entity_id"] entity_id = batch["entity_id"]
entity_id_same = (entity_id[..., None] == entity_id[..., None, :]) entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
rel_feats.append(entity_id_same[..., None]) rel_feats.append(entity_id_same[..., None].to(dtype=rel_pos.dtype))
sym_id = batch["sym_id"] sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :] rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
...@@ -577,7 +577,7 @@ class TemplateEmbedder(nn.Module): ...@@ -577,7 +577,7 @@ class TemplateEmbedder(nn.Module):
# a second copy during the stack later on # a second copy during the stack later on
t_pair = z.new_zeros( t_pair = z.new_zeros(
z.shape[:-3] + z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t) (n_templ, n, n, self.config.template_pair_embedder.c_out)
) )
for i in range(n_templ): for i in range(n_templ):
...@@ -667,17 +667,17 @@ class TemplatePairEmbedderMultimer(nn.Module): ...@@ -667,17 +667,17 @@ class TemplatePairEmbedderMultimer(nn.Module):
): ):
super(TemplatePairEmbedderMultimer, self).__init__() super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out) self.dgram_linear = Linear(c_dgram, c_out, init='relu')
self.aatype_linear_1 = Linear(c_aatype, c_out) self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
self.aatype_linear_2 = Linear(c_aatype, c_out) self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
self.query_embedding_layer_norm = LayerNorm(c_z) self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out) self.query_embedding_linear = Linear(c_z, c_out, init='relu')
self.pseudo_beta_mask_linear = Linear(1, c_out) self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
self.x_linear = Linear(1, c_out) self.x_linear = Linear(1, c_out, init='relu')
self.y_linear = Linear(1, c_out) self.y_linear = Linear(1, c_out, init='relu')
self.z_linear = Linear(1, c_out) self.z_linear = Linear(1, c_out, init='relu')
self.backbone_mask_linear = Linear(1, c_out) self.backbone_mask_linear = Linear(1, c_out, init='relu')
def forward(self, def forward(self,
template_dgram: torch.Tensor, template_dgram: torch.Tensor,
...@@ -812,10 +812,10 @@ class TemplateEmbedderMultimer(nn.Module): ...@@ -812,10 +812,10 @@ class TemplateEmbedderMultimer(nn.Module):
single_template_embeds = {} single_template_embeds = {}
act = 0. act = 0.
template_positions, pseudo_beta_mask = ( template_positions, pseudo_beta_mask = pseudo_beta_fn(
single_template_feats["template_pseudo_beta"], single_template_feats["template_aatype"],
single_template_feats["template_pseudo_beta_mask"], single_template_feats["template_all_atom_positions"],
) single_template_feats["template_all_atom_mask"])
template_dgram = dgram_from_positions( template_dgram = dgram_from_positions(
template_positions, template_positions,
......
...@@ -186,11 +186,6 @@ class AlphaFold(nn.Module): ...@@ -186,11 +186,6 @@ class AlphaFold(nn.Module):
if self.config.recycle_early_stop_tolerance < 0: if self.config.recycle_early_stop_tolerance < 0:
return False return False
if no_batch_dims == 0:
prev_pos = prev_pos.unsqueeze(dim=0)
next_pos = next_pos.unsqueeze(dim=0)
mask = mask.unsqueeze(dim=0)
ca_idx = residue_constants.atom_order['CA'] ca_idx = residue_constants.atom_order['CA']
sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2 sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
mask = mask[..., None] * mask[..., None, :] mask = mask[..., None] * mask[..., None, :]
...@@ -265,7 +260,7 @@ class AlphaFold(nn.Module): ...@@ -265,7 +260,7 @@ class AlphaFold(nn.Module):
requires_grad=False, requires_grad=False,
) )
x_prev = pseudo_beta_fn( pseudo_beta_x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None feats["aatype"], x_prev, None
).to(dtype=z.dtype) ).to(dtype=z.dtype)
...@@ -279,10 +274,12 @@ class AlphaFold(nn.Module): ...@@ -279,10 +274,12 @@ class AlphaFold(nn.Module):
m_1_prev_emb, z_prev_emb = self.recycling_embedder( m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev, m_1_prev,
z_prev, z_prev,
x_prev, pseudo_beta_x_prev,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
del pseudo_beta_x_prev
if(self.globals.offload_inference and inplace_safe): if(self.globals.offload_inference and inplace_safe):
m = m.to(m_1_prev_emb.device) m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device) z = z.to(z_prev.device)
......
...@@ -166,12 +166,14 @@ class PointProjection(nn.Module): ...@@ -166,12 +166,14 @@ class PointProjection(nn.Module):
c_hidden: int, c_hidden: int,
num_points: int, num_points: int,
no_heads: int, no_heads: int,
is_multimer: bool,
return_local_points: bool = False, return_local_points: bool = False,
): ):
super().__init__() super().__init__()
self.return_local_points = return_local_points self.return_local_points = return_local_points
self.no_heads = no_heads self.no_heads = no_heads
self.num_points = num_points self.num_points = num_points
self.is_multimer = is_multimer
self.linear = Linear(c_hidden, no_heads * 3 * num_points) self.linear = Linear(c_hidden, no_heads * 3 * num_points)
...@@ -181,24 +183,19 @@ class PointProjection(nn.Module): ...@@ -181,24 +183,19 @@ class PointProjection(nn.Module):
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training # TODO: Needs to run in high precision during training
points_local = self.linear(activations) points_local = self.linear(activations)
out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
if isinstance(rigids, Rigid3Array): if self.is_multimer:
points_local = points_local.reshape( points_local = points_local.view(
*points_local.shape[:-1], points_local.shape[:-1] + (self.no_heads, -1)
self.no_heads,
-1,
) )
points_local = torch.split( points_local = torch.split(
points_local, points_local.shape[-1] // 3, dim=-1 points_local, points_local.shape[-1] // 3, dim=-1
) )
points_local = torch.stack(points_local, dim=-1) points_local = torch.stack(points_local, dim=-1).view(out_shape)
if not isinstance(rigids, Rigid3Array):
points_local = points_local.reshape(
*points_local.shape[:-2], self.no_heads, -1, 3
)
points_global = rigids[..., None, None].apply(points_local) points_global = rigids[..., None, None].apply(points_local)
if(self.return_local_points): if(self.return_local_points):
...@@ -260,7 +257,8 @@ class InvariantPointAttention(nn.Module): ...@@ -260,7 +257,8 @@ class InvariantPointAttention(nn.Module):
self.linear_q_points = PointProjection( self.linear_q_points = PointProjection(
self.c_s, self.c_s,
self.no_qk_points, self.no_qk_points,
self.no_heads self.no_heads,
self.is_multimer
) )
if(is_multimer): if(is_multimer):
...@@ -270,12 +268,14 @@ class InvariantPointAttention(nn.Module): ...@@ -270,12 +268,14 @@ class InvariantPointAttention(nn.Module):
self.c_s, self.c_s,
self.no_qk_points, self.no_qk_points,
self.no_heads, self.no_heads,
self.is_multimer
) )
self.linear_v_points = PointProjection( self.linear_v_points = PointProjection(
self.c_s, self.c_s,
self.no_v_points, self.no_v_points,
self.no_heads, self.no_heads,
self.is_multimer
) )
else: else:
self.linear_kv = Linear(self.c_s, 2 * hc) self.linear_kv = Linear(self.c_s, 2 * hc)
...@@ -283,6 +283,7 @@ class InvariantPointAttention(nn.Module): ...@@ -283,6 +283,7 @@ class InvariantPointAttention(nn.Module):
self.c_s, self.c_s,
self.no_qk_points + self.no_v_points, self.no_qk_points + self.no_v_points,
self.no_heads, self.no_heads,
self.is_multimer
) )
self.linear_b = Linear(self.c_z, self.no_heads) self.linear_b = Linear(self.c_z, self.no_heads)
...@@ -504,6 +505,230 @@ class InvariantPointAttention(nn.Module): ...@@ -504,6 +505,230 @@ class InvariantPointAttention(nn.Module):
return s return s
#TODO: This module follows the refactoring done in IPA for multimer. Running the regular IPA above
# in multimer mode should be equivalent, but tests do not pass unless using this version. Determine
# whether or not the increase in test error matters in practice.
class InvariantPointAttentionMultimer(nn.Module):
    """
    Implements Algorithm 22 (Invariant Point Attention), multimer variant.

    Differs from the monomer IPA above in that the scalar q/k/v projections
    are bias-free, the point projections always use the multimer layout of
    PointProjection, and the attention logits are normalized by sqrt(1/3)
    (one term each for scalar, point, and pair contributions).
    """
    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_hidden: int,
        no_heads: int,
        no_qk_points: int,
        no_v_points: int,
        inf: float = 1e5,
        eps: float = 1e-8,
        is_multimer: bool = True,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            no_qk_points:
                Number of query/key points to generate
            no_v_points:
                Number of value points to generate
            inf:
                Large value used to mask attention logits
            eps:
                Small value (kept for interface parity; not read in forward)
            is_multimer:
                NOTE(review): accepted for signature parity with
                InvariantPointAttention but never stored or read; the point
                projections below hard-code is_multimer=True.
        """
        super(InvariantPointAttentionMultimer, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc, bias=False)

        self.linear_q_points = PointProjection(
            self.c_s,
            self.no_qk_points,
            self.no_heads,
            is_multimer=True
        )

        self.linear_k = Linear(self.c_s, hc, bias=False)
        self.linear_v = Linear(self.c_s, hc, bias=False)

        self.linear_k_points = PointProjection(
            self.c_s,
            self.no_qk_points,
            self.no_heads,
            is_multimer=True
        )

        self.linear_v_points = PointProjection(
            self.c_s,
            self.no_v_points,
            self.no_heads,
            is_multimer=True
        )

        self.linear_b = Linear(self.c_z, self.no_heads)

        # Per-head learnable weights for the point-distance attention term.
        self.head_weights = nn.Parameter(torch.zeros((no_heads)))
        ipa_point_weights_init_(self.head_weights)

        # Output consumes, per head: pair channels (c_z), scalar output
        # (c_hidden), and 4 values per value point (x, y, z, norm).
        concat_out_dim = self.no_heads * (
            self.c_z + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-2)

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Union[Rigid, Rigid3Array],
        mask: torch.Tensor,
        inplace_safe: bool = False,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
        Returns:
            [*, N_res, C_s] single representation update
        """
        # When offloading, the caller passes the pair representation through
        # _z_reference_list so this module can move it to CPU between uses;
        # otherwise wrap z in a single-element list for uniform access.
        if(_offload_inference and inplace_safe):
            z = _z_reference_list
        else:
            z = [z]

        # Attention logit accumulator; the three contributions (points,
        # scalars, pair bias) are added to it in turn.
        a = 0.

        # Weight so each point's squared-distance contribution has unit
        # variance under the initialization assumptions (9/2 per point).
        point_variance = (max(self.no_qk_points, 1) * 9.0 / 2)
        point_weights = math.sqrt(1.0 / point_variance)

        # logaddexp(x, 0) == softplus(x); written this way rather than
        # F.softplus — presumably for reduced-precision stability. TODO confirm.
        softplus = lambda x: torch.logaddexp(x, torch.zeros_like(x))

        head_weights = softplus(self.head_weights)
        point_weights = point_weights * head_weights

        #######################################
        # Generate scalar and point activations
        #######################################

        # [*, N_res, H, P_qk] points in global frame, as Vec3Array
        q_pts = Vec3Array.from_array(self.linear_q_points(s, r))

        # [*, N_res, H, P_qk, 3]
        k_pts = Vec3Array.from_array(self.linear_k_points(s, r))

        # Squared distance between every query/key residue pair, per head
        # and point; summed over points and scaled to form the point logits.
        pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.)
        pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)

        a = a + pt_att

        # Scale scalar queries by 1/sqrt(c_hidden) (applied to q only).
        scalar_variance = max(self.c_hidden, 1) * 1.
        scalar_weights = math.sqrt(1.0 / scalar_variance)

        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        k = self.linear_k(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))

        q = q * scalar_weights
        a = a + torch.einsum('...qhc,...khc->...qkh', q, k)

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H] pair bias
        b = self.linear_b(z[0])

        if (_offload_inference):
            # Refcount must be exactly 2 (the list slot + this local) so the
            # GPU copy is actually freed when replaced by the CPU copy.
            assert (sys.getrefcount(z[0]) == 2)
            z[0] = z[0].cpu()

        a = a + b

        # [*, N_res, N_res] pairwise mask; 0 where both residues are valid,
        # -inf (well, -self.inf) where either is masked.
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.inf * (square_mask - 1)
        a = a + square_mask.unsqueeze(-1)

        a = a * math.sqrt(1. / 3)  # Normalize by number of logit terms (3)
        # softmax over the key dimension (dim=-2, since logits are [..., q, k, h])
        a = self.softmax(a)

        # [*, N_res, H * C_hidden]
        v = self.linear_v(s)

        # [*, N_res, H, C_hidden]
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # Scalar attention output: weighted sum of values over keys.
        o = torch.einsum('...qkh, ...khc->...qhc', a, v)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, N_res, H, P_v, 3] value points in the global frame
        v_pts = Vec3Array.from_array(self.linear_v_points(s, r))

        # Weighted sum of value points over keys (broadcast a over the
        # Vec3Array's point dimension), then reduce the key dimension.
        # [*, N_res, H, P_v]
        o_pt = v_pts[..., None, :, :, :] * a.unsqueeze(-1)
        o_pt = o_pt.sum(dim=-3)

        # o_pt = Vec3Array(
        #     torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].x, dim=-3),
        #     torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].y, dim=-3),
        #     torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].z, dim=-3),
        # )

        # [*, N_res, H * P_v, 3] flatten heads and points together
        o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))

        # Map output points back into each residue's local frame.
        # [*, N_res, H, P_v]
        o_pt = r[..., None].apply_inverse_to_point(o_pt)
        o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]

        # [*, N_res, H * P_v] per-point norms (eps guards the sqrt gradient)
        o_pt_norm = o_pt.norm(epsilon=1e-8)

        if (_offload_inference):
            # Bring the pair representation back to the working device for
            # the pair-attention output term.
            z[0] = z[0].to(o_pt.device)

        # Pair output: attention-weighted average of pair embeddings per query.
        o_pair = torch.einsum('...ijh, ...ijc->...ihc', a, z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s] final projection of all concatenated outputs,
        # cast to the pair representation's dtype before the linear layer.
        s = self.linear_out(
            torch.cat(
                (o, *o_pt_flat, o_pt_norm, o_pair), dim=-1
            ).to(dtype=z[0].dtype)
        )

        return s
class BackboneUpdate(nn.Module): class BackboneUpdate(nn.Module):
""" """
Implements part of Algorithm 23. Implements part of Algorithm 23.
...@@ -670,7 +895,8 @@ class StructureModule(nn.Module): ...@@ -670,7 +895,8 @@ class StructureModule(nn.Module):
self.linear_in = Linear(self.c_s, self.c_s) self.linear_in = Linear(self.c_s, self.c_s)
self.ipa = InvariantPointAttention( ipa = InvariantPointAttention if not self.is_multimer else InvariantPointAttentionMultimer
self.ipa = ipa(
self.c_s, self.c_s,
self.c_z, self.c_z,
self.c_ipa, self.c_ipa,
......
...@@ -521,12 +521,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate): ...@@ -521,12 +521,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
def compute_projection(pair, mask): def compute_projection(pair, mask):
p = compute_projection_helper(pair, mask) p = compute_projection_helper(pair, mask)
if self._outgoing: left = p[..., :self.c_hidden]
left = p[..., :self.c_hidden] right = p[..., self.c_hidden:]
right = p[..., self.c_hidden:]
else:
left = p[..., self.c_hidden:]
right = p[..., :self.c_hidden]
return left, right return left, right
...@@ -580,12 +576,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate): ...@@ -580,12 +576,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
ab = ab * self.sigmoid(self.linear_ab_g(z)) ab = ab * self.sigmoid(self.linear_ab_g(z))
ab = ab * self.linear_ab_p(z) ab = ab * self.linear_ab_p(z)
if self._outgoing: a = ab[..., :self.c_hidden]
a = ab[..., :self.c_hidden] b = ab[..., self.c_hidden:]
b = ab[..., self.c_hidden:]
else:
b = ab[..., :self.c_hidden]
a = ab[..., self.c_hidden:]
# Prevents overflow of torch.matmul in combine projections in # Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes # reduced-precision modes
......
...@@ -36,9 +36,9 @@ def get_rc_tensor(rc_np, aatype): ...@@ -36,9 +36,9 @@ def get_rc_tensor(rc_np, aatype):
def atom14_to_atom37( def atom14_to_atom37(
atom14_data: torch.Tensor, # (*, N, 14, ...) atom14_data: torch.Tensor, # (*, N, 14, ...)
aatype: torch.Tensor # (*, N) aatype: torch.Tensor # (*, N)
) -> torch.Tensor: # (*, N, 37, ...) ) -> Tuple: # (*, N, 37, ...)
"""Convert atom14 to atom37 representation.""" """Convert atom14 to atom37 representation."""
idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype) idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype).long()
no_batch_dims = len(aatype.shape) - 1 no_batch_dims = len(aatype.shape) - 1
atom37_data = tensor_utils.batched_gather( atom37_data = tensor_utils.batched_gather(
atom14_data, atom14_data,
...@@ -50,10 +50,10 @@ def atom14_to_atom37( ...@@ -50,10 +50,10 @@ def atom14_to_atom37(
if len(atom14_data.shape) == no_batch_dims + 2: if len(atom14_data.shape) == no_batch_dims + 2:
atom37_data *= atom37_mask atom37_data *= atom37_mask
elif len(atom14_data.shape) == no_batch_dims + 3: elif len(atom14_data.shape) == no_batch_dims + 3:
atom37_data *= atom37_mask[..., None].astype(atom37_data.dtype) atom37_data *= atom37_mask[..., None].to(dtype=atom37_data.dtype)
else: else:
raise ValueError("Incorrectly shaped data") raise ValueError("Incorrectly shaped data")
return atom37_data return atom37_data, atom37_mask
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
...@@ -230,13 +230,13 @@ def torsion_angles_to_frames( ...@@ -230,13 +230,13 @@ def torsion_angles_to_frames(
num_residues = aatype.shape[-1] num_residues = aatype.shape[-1]
sin_angles = torch.cat( sin_angles = torch.cat(
[ [
torch.zeros_like(aatype).unsqueeze(), torch.zeros_like(aatype).unsqueeze(dim=-1),
sin_angles, sin_angles,
], ],
dim=-1) dim=-1)
cos_angles = torch.cat( cos_angles = torch.cat(
[ [
torch.ones_like(aatype).unsqueeze(), torch.ones_like(aatype).unsqueeze(dim=-1),
cos_angles cos_angles
], ],
dim=-1 dim=-1
......
...@@ -20,7 +20,7 @@ class QuatRigid(nn.Module): ...@@ -20,7 +20,7 @@ class QuatRigid(nn.Module):
def forward(self, activations: torch.Tensor) -> Rigid3Array: def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision # NOTE: During training, this needs to be run in higher precision
rigid_flat = self.linear(activations.to(torch.float32)) rigid_flat = self.linear(activations)
rigid_flat = torch.unbind(rigid_flat, dim=-1) rigid_flat = torch.unbind(rigid_flat, dim=-1)
if(self.full_quat): if(self.full_quat):
......
...@@ -172,20 +172,20 @@ class Rot3Array: ...@@ -172,20 +172,20 @@ class Rot3Array:
) -> Rot3Array: ) -> Rot3Array:
"""Construct Rot3Array from components of quaternion.""" """Construct Rot3Array from components of quaternion."""
if normalize: if normalize:
inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2) inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
w = w * inv_norm w = w * inv_norm
x = x * inv_norm x = x * inv_norm
y = y * inv_norm y = y * inv_norm
z = z * inv_norm z = z * inv_norm
xx = 1 - 2 * (y ** 2 + z ** 2) xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
xy = 2 * (x * y - w * z) xy = 2.0 * (x * y - w * z)
xz = 2 * (x * z + w * y) xz = 2.0 * (x * z + w * y)
yx = 2 * (x * y + w * z) yx = 2.0 * (x * y + w * z)
yy = 1 - 2 * (x ** 2 + z ** 2) yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
yz = 2 * (y * z - w * x) yz = 2.0 * (y * z - w * x)
zx = 2 * (x * z - w * y) zx = 2.0 * (x * z - w * y)
zy = 2 * (y * z + w * x) zy = 2.0 * (y * z + w * x)
zz = 1 - 2 * (x ** 2 + y ** 2) zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
def reshape(self, new_shape): def reshape(self, new_shape):
......
...@@ -28,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/" ...@@ -28,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
# With Param, a poor man's enum with attributes (Rust-style) # With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum): class ParamType(Enum):
LinearWeight = partial( # hack: partial prevents fns from becoming methods LinearWeight = partial( # hack: partial prevents fns from becoming methods
lambda w: w.transpose(-1, -2) lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2)
) )
LinearWeightMHA = partial( LinearWeightMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2) lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
...@@ -58,6 +58,7 @@ class Param: ...@@ -58,6 +58,7 @@ class Param:
param: Union[torch.Tensor, List[torch.Tensor]] param: Union[torch.Tensor, List[torch.Tensor]]
param_type: ParamType = ParamType.Other param_type: ParamType = ParamType.Other
stacked: bool = False stacked: bool = False
swap: bool = False
def process_translation_dict(d, top_layer=True): def process_translation_dict(d, top_layer=True):
...@@ -101,6 +102,7 @@ def stacked(param_dict_list, out=None): ...@@ -101,6 +102,7 @@ def stacked(param_dict_list, out=None):
param=[param.param for param in v], param=[param.param for param in v],
param_type=v[0].param_type, param_type=v[0].param_type,
stacked=True, stacked=True,
swap=v[0].swap
) )
out[k] = stacked_param out[k] = stacked_param
...@@ -122,7 +124,12 @@ def assign(translation_dict, orig_weights): ...@@ -122,7 +124,12 @@ def assign(translation_dict, orig_weights):
try: try:
weights = list(map(param_type.transformation, weights)) weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights): for p, w in zip(ref, weights):
p.copy_(w) if param.swap:
index = p.shape[0] // 2
p[:index].copy_(w[index:])
p[index:].copy_(w[:index])
else:
p.copy_(w)
except: except:
print(k) print(k)
print(ref[0].shape) print(ref[0].shape)
...@@ -145,12 +152,24 @@ def generate_translation_dict(model, version, is_multimer=False): ...@@ -145,12 +152,24 @@ def generate_translation_dict(model, version, is_multimer=False):
LinearBiasMultimer = lambda l: ( LinearBiasMultimer = lambda l: (
Param(l, param_type=ParamType.LinearBiasMultimer) Param(l, param_type=ParamType.LinearBiasMultimer)
) )
LinearWeightSwap = lambda l: (Param(l, param_type=ParamType.LinearWeight, swap=True))
LinearBiasSwap = lambda l: (Param(l, swap=True))
LinearParams = lambda l: { LinearParams = lambda l: {
"weights": LinearWeight(l.weight), "weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias), "bias": LinearBias(l.bias),
} }
LinearParamsMHA = lambda l: {
"weights": LinearWeightMHA(l.weight),
"bias": LinearBiasMHA(l.bias),
}
LinearParamsSwap = lambda l: {
"weights": LinearWeightSwap(l.weight),
"bias": LinearBiasSwap(l.bias),
}
LinearParamsMultimer = lambda l: { LinearParamsMultimer = lambda l: {
"weights": LinearWeightMultimer(l.weight), "weights": LinearWeightMultimer(l.weight),
"bias": LinearBiasMultimer(l.bias), "bias": LinearBiasMultimer(l.bias),
...@@ -194,10 +213,11 @@ def generate_translation_dict(model, version, is_multimer=False): ...@@ -194,10 +213,11 @@ def generate_translation_dict(model, version, is_multimer=False):
def TriMulOutParams(tri_mul, outgoing=True): def TriMulOutParams(tri_mul, outgoing=True):
if re.fullmatch("^model_[1-5]_multimer_v3$", version): if re.fullmatch("^model_[1-5]_multimer_v3$", version):
lin_param_type = LinearParams if outgoing else LinearParamsSwap
d = { d = {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in), "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"projection": LinearParams(tri_mul.linear_ab_p), "projection": lin_param_type(tri_mul.linear_ab_p),
"gate": LinearParams(tri_mul.linear_ab_g), "gate": lin_param_type(tri_mul.linear_ab_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out), "center_norm": LayerNormParams(tri_mul.layer_norm_out),
} }
else: else:
...@@ -276,24 +296,24 @@ def generate_translation_dict(model, version, is_multimer=False): ...@@ -276,24 +296,24 @@ def generate_translation_dict(model, version, is_multimer=False):
} }
PointProjectionParams = lambda pp: { PointProjectionParams = lambda pp: {
"point_projection": LinearParamsMultimer( "point_projection": LinearParamsMHA(
pp.linear, pp.linear,
), ),
} }
IPAParamsMultimer = lambda ipa: { IPAParamsMultimer = lambda ipa: {
"q_scalar_projection": { "q_scalar_projection": {
"weights": LinearWeightMultimer( "weights": LinearWeightMHA(
ipa.linear_q.weight, ipa.linear_q.weight,
), ),
}, },
"k_scalar_projection": { "k_scalar_projection": {
"weights": LinearWeightMultimer( "weights": LinearWeightMHA(
ipa.linear_k.weight, ipa.linear_k.weight,
), ),
}, },
"v_scalar_projection": { "v_scalar_projection": {
"weights": LinearWeightMultimer( "weights": LinearWeightMHA(
ipa.linear_v.weight, ipa.linear_v.weight,
), ),
}, },
...@@ -574,7 +594,7 @@ def generate_translation_dict(model, version, is_multimer=False): ...@@ -574,7 +594,7 @@ def generate_translation_dict(model, version, is_multimer=False):
"template_pair_embedding_0": LinearParams( "template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear temp_embedder.template_pair_embedder.dgram_linear
), ),
"template_pair_embedding_1": LinearParamsMultimer( "template_pair_embedding_1": LinearParams(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
), ),
"template_pair_embedding_2": LinearParams( "template_pair_embedding_2": LinearParams(
...@@ -583,16 +603,16 @@ def generate_translation_dict(model, version, is_multimer=False): ...@@ -583,16 +603,16 @@ def generate_translation_dict(model, version, is_multimer=False):
"template_pair_embedding_3": LinearParams( "template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2 temp_embedder.template_pair_embedder.aatype_linear_2
), ),
"template_pair_embedding_4": LinearParamsMultimer( "template_pair_embedding_4": LinearParams(
temp_embedder.template_pair_embedder.x_linear temp_embedder.template_pair_embedder.x_linear
), ),
"template_pair_embedding_5": LinearParamsMultimer( "template_pair_embedding_5": LinearParams(
temp_embedder.template_pair_embedder.y_linear temp_embedder.template_pair_embedder.y_linear
), ),
"template_pair_embedding_6": LinearParamsMultimer( "template_pair_embedding_6": LinearParams(
temp_embedder.template_pair_embedder.z_linear temp_embedder.template_pair_embedder.z_linear
), ),
"template_pair_embedding_7": LinearParamsMultimer( "template_pair_embedding_7": LinearParams(
temp_embedder.template_pair_embedder.backbone_mask_linear temp_embedder.template_pair_embedder.backbone_mask_linear
), ),
"template_pair_embedding_8": LinearParams( "template_pair_embedding_8": LinearParams(
...@@ -600,7 +620,7 @@ def generate_translation_dict(model, version, is_multimer=False): ...@@ -600,7 +620,7 @@ def generate_translation_dict(model, version, is_multimer=False):
), ),
"template_embedding_iteration": tps_blocks_params, "template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams( "output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm temp_embedder.template_pair_stack.layer_norm
), ),
}, },
"output_linear": LinearParams( "output_linear": LinearParams(
......
...@@ -1643,7 +1643,7 @@ def chain_center_of_mass_loss( ...@@ -1643,7 +1643,7 @@ def chain_center_of_mass_loss(
all_atom_positions = all_atom_positions[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True) chains = asym_id.unique()
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1) chain_pos_mask = one_hot.transpose(-2, -1)
......
...@@ -431,7 +431,7 @@ if __name__ == "__main__": ...@@ -431,7 +431,7 @@ if __name__ == "__main__":
help="""Postfix for output prediction filenames""" help="""Postfix for output prediction filenames"""
) )
parser.add_argument( parser.add_argument(
"--data_random_seed", type=str, default=None "--data_random_seed", type=int, default=None
) )
parser.add_argument( parser.add_argument(
"--skip_relaxation", action="store_true", default=False, "--skip_relaxation", action="store_true", default=False,
......
...@@ -45,7 +45,7 @@ def main(args): ...@@ -45,7 +45,7 @@ def main(args):
uniref90_database_path=args.uniref90_database_path, uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path, mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path, bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path, uniref30_database_path=args.uniref30_database_path,
small_bfd_database_path=None, small_bfd_database_path=None,
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
template_searcher=template_searcher, template_searcher=template_searcher,
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
import torch import torch
import numpy as np import numpy as np
import unittest import unittest
from pathlib import Path
from tests.config import consts
from openfold.config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_ from openfold.utils.import_weights import import_jax_weights_
...@@ -23,15 +25,17 @@ from openfold.utils.import_weights import import_jax_weights_ ...@@ -23,15 +25,17 @@ from openfold.utils.import_weights import import_jax_weights_
class TestImportWeights(unittest.TestCase): class TestImportWeights(unittest.TestCase):
def test_import_jax_weights_(self): def test_import_jax_weights_(self):
npz_path = "openfold/resources/params/params_model_1_ptm.npz" npz_path = Path(__file__).parent.resolve() / f"../openfold/resources/params/params_{consts.model}.npz"
c = model_config("model_1_ptm") c = model_config(consts.model)
c.globals.blocks_per_ckpt = None c.globals.blocks_per_ckpt = None
model = AlphaFold(c) model = AlphaFold(c)
model.eval()
import_jax_weights_( import_jax_weights_(
model, model,
npz_path, npz_path,
version=consts.model
) )
data = np.load(npz_path) data = np.load(npz_path)
......
...@@ -22,7 +22,7 @@ from openfold.data.data_modules import ( ...@@ -22,7 +22,7 @@ from openfold.data.data_modules import (
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_ from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import ( from openfold.utils.callbacks import (
EarlyStoppingVerbose, EarlyStoppingVerbose,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment