"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "06eac9eaa70be4ca9aacd82a323134cfde3604b2"
Commit 30764cf9 authored by Christina Floristean
Browse files

Minor fixes/reformatting for recent multimer training PR

parent 31051cf2
...@@ -160,10 +160,10 @@ def model_config( ...@@ -160,10 +160,10 @@ def model_config(
c.loss.masked_msa.num_classes = 22 c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20 c.data.common.max_recycling_iters = 20
for k,v in multimer_model_config_update['model'].items(): for k, v in multimer_model_config_update['model'].items():
c.model[k] = v c.model[k] = v
for k,v in multimer_model_config_update['loss'].items(): for k, v in multimer_model_config_update['loss'].items():
c.loss[k] = v c.loss[k] = v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
...@@ -683,24 +683,11 @@ config = mlc.ConfigDict( ...@@ -683,24 +683,11 @@ config = mlc.ConfigDict(
) )
multimer_model_config_update = { multimer_model_config_update = {
'model':{"input_embedder": { 'model': {
"tf_dim": 21, "input_embedder": {
"msa_dim": 49, "tf_dim": 21,
#"num_msa": 508, "msa_dim": 49,
"c_z": c_z, #"num_msa": 508,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z, "c_z": c_z,
"c_m": c_m, "c_m": c_m,
"relpos_k": 32, "relpos_k": 32,
...@@ -841,8 +828,6 @@ multimer_model_config_update = { ...@@ -841,8 +828,6 @@ multimer_model_config_update = {
}, },
"recycle_early_stop_tolerance": 0.5 "recycle_early_stop_tolerance": 0.5
}, },
"recycle_early_stop_tolerance": 0.5
},
"loss": { "loss": {
"distogram": { "distogram": {
"min_bin": 2.3125, "min_bin": 2.3125,
...@@ -863,7 +848,7 @@ multimer_model_config_update = { ...@@ -863,7 +848,7 @@ multimer_model_config_update = {
"loss_unit_distance": 10.0, "loss_unit_distance": 10.0,
"weight": 0.5, "weight": 0.5,
}, },
"interface": { "interface_backbone": {
"clamp_distance": 30.0, "clamp_distance": 30.0,
"loss_unit_distance": 20.0, "loss_unit_distance": 20.0,
"weight": 0.5, "weight": 0.5,
...@@ -918,5 +903,5 @@ multimer_model_config_update = { ...@@ -918,5 +903,5 @@ multimer_model_config_update = {
"enabled": True, "enabled": True,
}, },
"eps": eps, "eps": eps,
}, }
} }
...@@ -7,7 +7,6 @@ import pickle ...@@ -7,7 +7,6 @@ import pickle
from typing import Optional, Sequence, List, Any from typing import Optional, Sequence, List, Any
import ml_collections as mlc import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.utils.data import RandomSampler from torch.utils.data import RandomSampler
...@@ -18,7 +17,7 @@ from openfold.data import ( ...@@ -18,7 +17,7 @@ from openfold.data import (
mmcif_parsing, mmcif_parsing,
templates, templates,
) )
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap from openfold.utils.tensor_utils import dict_multimap
import contextlib import contextlib
import tempfile import tempfile
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
...@@ -34,6 +33,7 @@ def temp_fasta_file(sequence_str): ...@@ -34,6 +33,7 @@ def temp_fasta_file(sequence_str):
fasta_file.seek(0) fasta_file.seek(0)
yield fasta_file.name yield fasta_file.name
class OpenFoldSingleDataset(torch.utils.data.Dataset): class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
data_dir: str, data_dir: str,
...@@ -296,6 +296,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -296,6 +296,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
return len(self._chain_ids) return len(self._chain_ids)
class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
data_dir: str, data_dir: str,
...@@ -549,10 +550,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -549,10 +550,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
device=all_chain_features["aatype"].device) device=all_chain_features["aatype"].device)
return all_chain_features return all_chain_features
def __len__(self): def __len__(self):
return len(self._chain_ids) return len(self._chain_ids)
def deterministic_train_filter( def deterministic_train_filter(
chain_data_cache_entry: Any, chain_data_cache_entry: Any,
max_resolution: float = 9., max_resolution: float = 9.,
...@@ -575,6 +576,7 @@ def deterministic_train_filter( ...@@ -575,6 +576,7 @@ def deterministic_train_filter(
return True return True
def deterministic_multimer_train_filter( def deterministic_multimer_train_filter(
mmcif_data_cache_entry, mmcif_data_cache_entry,
max_resolution:float= 9., max_resolution:float= 9.,
...@@ -613,9 +615,10 @@ def deterministic_multimer_train_filter( ...@@ -613,9 +615,10 @@ def deterministic_multimer_train_filter(
return True return True
def get_stochastic_train_filter_prob( def get_stochastic_train_filter_prob(
chain_data_cache_entry: Any, chain_data_cache_entry: Any,
) -> List[float]: ) -> float:
# Stochastic filters # Stochastic filters
probabilities = [] probabilities = []
...@@ -723,8 +726,8 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -723,8 +726,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datapoint_idx = next(samples) datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx)) self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldMultimerDataset(torch.utils.data.Dataset): class OpenFoldMultimerDataset(torch.utils.data.Dataset):
""" """
Create a torch Dataset object for multimer training and Create a torch Dataset object for multimer training and
add filtering steps described in AlphaFold Multimer's paper: add filtering steps described in AlphaFold Multimer's paper:
...@@ -753,7 +756,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -753,7 +756,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
chains = mmcif_data_cache[mmcif_id]['chain_ids'] chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id] mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if deterministic_multimer_train_filter(mmcif_data_cache_entry, if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,minimum_number_of_residues=5): max_resolution=9,
minimum_number_of_residues=5):
selected_idx.append(i) selected_idx.append(i)
return selected_idx return selected_idx
...@@ -781,11 +785,13 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -781,11 +785,13 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
logging.info(f"self.epoch_len is {self.epoch_len}") logging.info(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ] self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __call__(self, prots): def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0) stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, prots) return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader): class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs): def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -873,6 +879,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -873,6 +879,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return _batch_prop_gen(it) return _batch_prop_gen(it)
class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader): class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs): def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs) super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs)
...@@ -1110,7 +1117,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -1110,7 +1117,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
def predict_dataloader(self): def predict_dataloader(self):
return self._gen_dataloader("predict") return self._gen_dataloader("predict")
class OpenFoldMultimerDataModule(OpenFoldDataModule): class OpenFoldMultimerDataModule(OpenFoldDataModule):
""" """
Create a datamodule specifically for multimer training Create a datamodule specifically for multimer training
......
...@@ -784,45 +784,6 @@ class DataPipeline: ...@@ -784,45 +784,6 @@ class DataPipeline:
return all_hits return all_hits
def _parse_template_hits(
    self,
    alignment_dir: str,
    alignment_index: Optional[Any] = None,
    input_sequence=None,
) -> Mapping[str, Any]:
    """Collect parsed template hits found under ``alignment_dir``.

    Two storage layouts are handled:

    * Indexed: when ``alignment_index`` is given, the file named by
      ``alignment_index["db"]`` (relative to ``alignment_dir``) is opened
      in binary mode and each ``(name, start, size)`` entry of
      ``alignment_index["files"]`` is read from it by byte offset.  Only
      entries whose extension is ``.hhr`` are parsed in this layout.
    * Directory: otherwise every file in ``alignment_dir`` is inspected;
      ``.hhr`` files go through ``parsers.parse_hhr`` and ``.sto`` files
      whose name starts with ``"hmm"`` go through
      ``parsers.parse_hmmsearch_sto`` together with ``input_sequence``.

    Args:
        alignment_dir: Directory holding the alignment files or the
            indexed database file.
        alignment_index: Optional mapping with the keys ``"db"`` and
            ``"files"`` described above.
        input_sequence: Forwarded unchanged to
            ``parsers.parse_hmmsearch_sto``; unused for ``.hhr`` input.

    Returns:
        A dict mapping each parsed file/entry name to whatever the
        corresponding parser returned.  Empty when nothing matched.
    """
    all_hits = {}

    if alignment_index is not None:
        db_path = os.path.join(alignment_dir, alignment_index["db"])
        # The context manager guarantees the database handle is released
        # even if seeking, decoding or parsing raises part-way through.
        with open(db_path, "rb") as fp:
            for name, start, size in alignment_index["files"]:
                if os.path.splitext(name)[-1] != ".hhr":
                    continue
                fp.seek(start)
                raw = fp.read(size).decode("utf-8")
                all_hits[name] = parsers.parse_hhr(raw)
        return all_hits

    for f in os.listdir(alignment_dir):
        path = os.path.join(alignment_dir, f)
        ext = os.path.splitext(f)[-1]

        if ext == ".hhr":
            with open(path, "r") as fp:
                hits = parsers.parse_hhr(fp.read())
            all_hits[f] = hits
        elif ext == ".sto" and f.startswith("hmm"):
            with open(path, "r") as fp:
                hits = parsers.parse_hmmsearch_sto(fp.read(), input_sequence)
            all_hits[f] = hits
        # No explicit close() is needed here: every handle above is
        # already closed by its ``with`` block.

    return all_hits
def _get_msas(self, def _get_msas(self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
...@@ -879,9 +840,9 @@ class DataPipeline: ...@@ -879,9 +840,9 @@ class DataPipeline:
num_res = len(input_sequence) num_res = len(input_sequence)
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir=alignment_dir,
input_sequence, input_sequence=input_sequence,
alignment_index, alignment_index=alignment_index,
) )
template_features = make_template_features( template_features = make_template_features(
...@@ -928,8 +889,9 @@ class DataPipeline: ...@@ -928,8 +889,9 @@ class DataPipeline:
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir=alignment_dir,
alignment_index,input_sequence) input_sequence=input_sequence,
alignment_index=alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
...@@ -976,8 +938,9 @@ class DataPipeline: ...@@ -976,8 +938,9 @@ class DataPipeline:
) )
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir=alignment_dir,
alignment_index,input_sequence input_sequence=input_sequence,
alignment_index=alignment_index,
) )
template_features = make_template_features( template_features = make_template_features(
...@@ -1008,8 +971,9 @@ class DataPipeline: ...@@ -1008,8 +971,9 @@ class DataPipeline:
core_feats = make_protein_features(protein_object, description) core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir=alignment_dir,
alignment_index,input_sequence input_sequence=input_sequence,
alignment_index=alignment_index,
) )
template_features = make_template_features( template_features = make_template_features(
...@@ -1098,7 +1062,10 @@ class DataPipeline: ...@@ -1098,7 +1062,10 @@ class DataPipeline:
alignment_dir = os.path.join( alignment_dir = os.path.join(
super_alignment_dir, desc super_alignment_dir, desc
) )
hits = self._parse_template_hits(alignment_dir, alignment_index=None,input_sequence=input_sequence) hits = self._parse_template_hit_files(alignment_dir=alignment_dir,
input_sequence=seq,
alignment_index=None)
template_features = make_template_features( template_features = make_template_features(
seq, seq,
hits, hits,
......
...@@ -310,10 +310,10 @@ def fape_loss( ...@@ -310,10 +310,10 @@ def fape_loss(
interface_bb_loss = backbone_loss( interface_bb_loss = backbone_loss(
traj=traj, traj=traj,
pair_mask=1. - intra_chain_mask, pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface}, **{**batch, **config.interface_backbone},
) )
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface.weight) + interface_bb_loss * config.interface_backbone.weight)
else: else:
bb_loss = backbone_loss( bb_loss = backbone_loss(
traj=traj, traj=traj,
...@@ -541,8 +541,11 @@ def lddt_loss( ...@@ -541,8 +541,11 @@ def lddt_loss(
cutoff=cutoff, cutoff=cutoff,
eps=eps eps=eps
) )
score = torch.nan_to_num(score,nan=torch.nanmean(score))
# TODO: Remove after initial pipeline testing
score = torch.nan_to_num(score, nan=torch.nanmean(score))
score[score<0] = 0 score[score<0] = 0
score = score.detach() score = score.detach()
bin_index = torch.floor(score * no_bins).long() bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1)) bin_index = torch.clamp(bin_index, max=(no_bins - 1))
...@@ -1233,7 +1236,7 @@ def find_structural_violations( ...@@ -1233,7 +1236,7 @@ def find_structural_violations(
batch["atom14_atom_exists"] batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]] * atomtype_radius[batch["residx_atom14_to_atom37"]]
) )
torch.cuda.memory_summary()
# Compute the between residue clash loss. # Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss( between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions, atom14_pred_positions=atom14_pred_positions,
...@@ -1665,9 +1668,11 @@ def chain_center_of_mass_loss( ...@@ -1665,9 +1668,11 @@ def chain_center_of_mass_loss(
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True) chains = asym_id.unique()
one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64)-1, # have to reduce asym_id by one because class values must be smaller than num_classes
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int # Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
one_hot = torch.nn.functional.one_hot(asym_id.long() - 1,
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1) chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float() chain_exists = torch.any(chain_pos_mask, dim=-1).float()
...@@ -1688,6 +1693,7 @@ def chain_center_of_mass_loss( ...@@ -1688,6 +1693,7 @@ def chain_center_of_mass_loss(
loss = masked_mean(loss_mask, losses, dim=(-1, -2)) loss = masked_mean(loss_mask, losses, dim=(-1, -2))
return loss return loss
# # # #
# below are the functions required for permutations # below are the functions required for permutations
# # # #
...@@ -1715,6 +1721,7 @@ def kabsch_rotation(P, Q): ...@@ -1715,6 +1721,7 @@ def kabsch_rotation(P, Q):
assert rotation.shape == torch.Size([3,3]) assert rotation.shape == torch.Size([3,3])
return rotation.to('cuda') return rotation.to('cuda')
def get_optimal_transform( def get_optimal_transform(
src_atoms: torch.Tensor, src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor, tgt_atoms: torch.Tensor,
......
...@@ -51,7 +51,8 @@ def get_alphafold_config(): ...@@ -51,7 +51,8 @@ def get_alphafold_config():
return config return config
_param_path = f"openfold/resources/params/params_{consts.model}.npz" dir_path = os.path.dirname(os.path.realpath(__file__))
_param_path = os.path.join(dir_path, "..", f"openfold/resources/params/params_{consts.model}.npz")
_model = None _model = None
......
...@@ -256,7 +256,7 @@ class Template(unittest.TestCase): ...@@ -256,7 +256,7 @@ class Template(unittest.TestCase):
template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
if consts.is_multimer: if consts.is_multimer:
out_repro = model.template_embedder( out_repro_all = model.template_embedder(
template_feats, template_feats,
torch.as_tensor(pair_act).cuda(), torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(), torch.as_tensor(pair_mask).cuda(),
...@@ -267,7 +267,7 @@ class Template(unittest.TestCase): ...@@ -267,7 +267,7 @@ class Template(unittest.TestCase):
inplace_safe=False inplace_safe=False
) )
else: else:
out_repro = model.template_embedder( out_repro_all = model.template_embedder(
template_feats, template_feats,
torch.as_tensor(pair_act).cuda(), torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(), torch.as_tensor(pair_mask).cuda(),
...@@ -277,10 +277,10 @@ class Template(unittest.TestCase): ...@@ -277,10 +277,10 @@ class Template(unittest.TestCase):
inplace_safe=False inplace_safe=False
) )
out_repro = out_repro["template_pair_embedding"] out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment