"...triton-llm/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "dddebc0df2a6952a508cd1a127c7aff0bc60d934"
Commit 30764cf9 authored by Christina Floristean
Browse files

Minor fixes/reformatting for recent multimer training PR

parent 31051cf2
......@@ -160,10 +160,10 @@ def model_config(
c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20
for k,v in multimer_model_config_update['model'].items():
for k, v in multimer_model_config_update['model'].items():
c.model[k] = v
for k,v in multimer_model_config_update['loss'].items():
for k, v in multimer_model_config_update['loss'].items():
c.loss[k] = v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
......@@ -683,24 +683,11 @@ config = mlc.ConfigDict(
)
multimer_model_config_update = {
'model':{"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
'model': {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
......@@ -841,8 +828,6 @@ multimer_model_config_update = {
},
"recycle_early_stop_tolerance": 0.5
},
"recycle_early_stop_tolerance": 0.5
},
"loss": {
"distogram": {
"min_bin": 2.3125,
......@@ -863,7 +848,7 @@ multimer_model_config_update = {
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"interface": {
"interface_backbone": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5,
......@@ -918,5 +903,5 @@ multimer_model_config_update = {
"enabled": True,
},
"eps": eps,
},
}
}
......@@ -7,7 +7,6 @@ import pickle
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
......@@ -18,7 +17,7 @@ from openfold.data import (
mmcif_parsing,
templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
from openfold.utils.tensor_utils import dict_multimap
import contextlib
import tempfile
from openfold.utils.tensor_utils import (
......@@ -34,6 +33,7 @@ def temp_fasta_file(sequence_str):
fasta_file.seek(0)
yield fasta_file.name
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
......@@ -296,6 +296,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __len__(self):
    """Return the number of entries held in ``self._chain_ids``."""
    return len(self._chain_ids)
class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
......@@ -549,10 +550,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
device=all_chain_features["aatype"].device)
return all_chain_features
def __len__(self):
    """Return the number of entries held in ``self._chain_ids``."""
    return len(self._chain_ids)
def deterministic_train_filter(
chain_data_cache_entry: Any,
max_resolution: float = 9.,
......@@ -575,6 +576,7 @@ def deterministic_train_filter(
return True
def deterministic_multimer_train_filter(
mmcif_data_cache_entry,
max_resolution:float= 9.,
......@@ -613,9 +615,10 @@ def deterministic_multimer_train_filter(
return True
def get_stochastic_train_filter_prob(
chain_data_cache_entry: Any,
) -> List[float]:
) -> float:
# Stochastic filters
probabilities = []
......@@ -723,8 +726,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldMultimerDataset(torch.utils.data.Dataset):
"""
Create a torch Dataset object for multimer training and
add filtering steps described in AlphaFold Multimer's paper:
......@@ -753,7 +756,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,minimum_number_of_residues=5):
max_resolution=9,
minimum_number_of_residues=5):
selected_idx.append(i)
return selected_idx
......@@ -781,11 +785,13 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
logging.info(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
class OpenFoldBatchCollator:
    """Callable that merges a list of per-example dicts into one batch.

    The merge is delegated to ``dict_multimap``; the function handed to it
    is ``torch.stack`` bound to ``dim=0``.
    """

    def __call__(self, prots):
        # Hand ``dict_multimap`` a stacker that always uses dim 0.
        return dict_multimap(partial(torch.stack, dim=0), prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs)
......@@ -873,6 +879,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return _batch_prop_gen(it)
class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs)
......@@ -1110,7 +1117,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
def predict_dataloader(self):
    """Return whatever ``self._gen_dataloader`` builds for the "predict" stage."""
    return self._gen_dataloader("predict")
class OpenFoldMultimerDataModule(OpenFoldDataModule):
"""
Create a datamodule specifically for multimer training
......
......@@ -784,45 +784,6 @@ class DataPipeline:
return all_hits
def _parse_template_hits(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None,
input_sequence=None,
) -> Mapping[str, Any]:
all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
elif (ext =='.sto') and (f.startswith("hmm")):
with open(path,"r") as fp:
hits = parsers.parse_hmmsearch_sto(fp.read(),input_sequence)
all_hits[f] = hits
fp.close()
return all_hits
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
......@@ -879,9 +840,9 @@ class DataPipeline:
num_res = len(input_sequence)
hits = self._parse_template_hit_files(
alignment_dir,
input_sequence,
alignment_index,
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index,
)
template_features = make_template_features(
......@@ -928,8 +889,9 @@ class DataPipeline:
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hit_files(
alignment_dir,
alignment_index,input_sequence)
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index)
template_features = make_template_features(
input_sequence,
......@@ -976,8 +938,9 @@ class DataPipeline:
)
hits = self._parse_template_hit_files(
alignment_dir,
alignment_index,input_sequence
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index,
)
template_features = make_template_features(
......@@ -1008,8 +971,9 @@ class DataPipeline:
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hit_files(
alignment_dir,
alignment_index,input_sequence
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index,
)
template_features = make_template_features(
......@@ -1098,7 +1062,10 @@ class DataPipeline:
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, alignment_index=None,input_sequence=input_sequence)
hits = self._parse_template_hit_files(alignment_dir=alignment_dir,
input_sequence=seq,
alignment_index=None)
template_features = make_template_features(
seq,
hits,
......
......@@ -310,10 +310,10 @@ def fape_loss(
interface_bb_loss = backbone_loss(
traj=traj,
pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface},
**{**batch, **config.interface_backbone},
)
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface.weight)
+ interface_bb_loss * config.interface_backbone.weight)
else:
bb_loss = backbone_loss(
traj=traj,
......@@ -541,8 +541,11 @@ def lddt_loss(
cutoff=cutoff,
eps=eps
)
score = torch.nan_to_num(score,nan=torch.nanmean(score))
# TODO: Remove after initial pipeline testing
score = torch.nan_to_num(score, nan=torch.nanmean(score))
score[score<0] = 0
score = score.detach()
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
......@@ -1233,7 +1236,7 @@ def find_structural_violations(
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
)
torch.cuda.memory_summary()
# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
......@@ -1665,9 +1668,11 @@ def chain_center_of_mass_loss(
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True)
one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64)-1, # have to reduce asym_id by one because class values must be smaller than num_classes
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int
chains = asym_id.unique()
# Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
one_hot = torch.nn.functional.one_hot(asym_id.long() - 1,
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float()
......@@ -1688,6 +1693,7 @@ def chain_center_of_mass_loss(
loss = masked_mean(loss_mask, losses, dim=(-1, -2))
return loss
# #
# below are the functions required for permutations
# #
......@@ -1715,6 +1721,7 @@ def kabsch_rotation(P, Q):
assert rotation.shape == torch.Size([3,3])
return rotation.to('cuda')
def get_optimal_transform(
src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor,
......
......@@ -51,7 +51,8 @@ def get_alphafold_config():
return config
_param_path = f"openfold/resources/params/params_{consts.model}.npz"
dir_path = os.path.dirname(os.path.realpath(__file__))
_param_path = os.path.join(dir_path, "..", f"openfold/resources/params/params_{consts.model}.npz")
_model = None
......
......@@ -256,7 +256,7 @@ class Template(unittest.TestCase):
template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
if consts.is_multimer:
out_repro = model.template_embedder(
out_repro_all = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
......@@ -267,7 +267,7 @@ class Template(unittest.TestCase):
inplace_safe=False
)
else:
out_repro = model.template_embedder(
out_repro_all = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
......@@ -277,10 +277,10 @@ class Template(unittest.TestCase):
inplace_safe=False
)
out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment