Unverified Commit 6f78792d authored by Dingquan Yu's avatar Dingquan Yu Committed by GitHub

Merge pull request #6 from dingquanyu/multimer-dataloader

Added Multimer dataloader and training scripts
parents 85185efd 5fc80134
# Permutation code README
## Overview: before running the training script
NB: before running the test codes, please download the procrustes package first
from https://github.com/theochem/procrustes.
To test the permutation codes:
```bash
python -m unittest tests/test_permutation.py
```
To test the multimer training codes, make sure that the output of running ```scripts/generate_mmcif_cache.py``` is ready and available in ```tests/test_data```.
I have uploaded the json file to owncloud [here](https://oc.embl.de/index.php/s/wVUwc1IHiJUt9sP).
```bash
python3 train_openfold.py /g/alphafold/AlphaFold_DBs/pdb_mmcif/mmcif_files/ \
    tests/test_data/alignments/ \
    /g/alphafold/AlphaFold_DBs/pdb_mmcif/mmcif_files/ \
    /scratch/gyu/train_openfold_output \
    2500-01-01 \
    --train_mmcif_data_cache_path=tests/test_data/train_mmcifs_cache.json \
    --template_release_dates_cache_path=tests/test_data/mmcif_cache.json \
    --config_preset=model_1_multimer_v3 --seed=42 --gpus=1
```
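For orientation, the training cache passed via `--train_mmcif_data_cache_path` is a json keyed by pdb id, with `release_date`, `chain_ids`, `seqs`, `no_chains` and `resolution` fields (the actual file appears later in this diff). A minimal sketch of building such an entry, using one real structure from the cache:

```python
import json
import os
import tempfile

# Sketch: build a minimal train_mmcifs_cache.json-style file. Field names
# mirror the cache in tests/test_data; the temp path is just for illustration.
cache = {
    "1hf9": {
        "release_date": "2001-05-31",
        "chain_ids": ["A", "B"],
        "seqs": ["ALKKHHENEISHHAKEIERLQKEIERHKQSIKKLKQSEDDD"] * 2,
        "no_chains": 2,
        "resolution": 0.0,
    }
}

path = os.path.join(tempfile.mkdtemp(), "train_mmcifs_cache.json")
with open(path, "w") as f:
    json.dump(cache, f)

# Round-trip to confirm the file is valid json with the expected fields.
with open(path) as f:
    loaded = json.load(f)
print(loaded["1hf9"]["no_chains"])  # 2
```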
Unlike monomer training, ```chain_cache_data``` is not required, but ```train_mmcifs_cache``` is. In this case, I selected 9 mmcifs that are already in the previous ```test_data``` folder as a training set. ```./tests/test_data/train_mmcifs_cache.json``` in the command above records the information of these 9 structures and is needed to run the training code.

The files that have been changed are:
[```openfold/utils/loss.py```](https://github.com/dingquanyu/openfold/blob/permutation/openfold/utils/loss.py), in which the forward function of the original ```AlphaFoldLoss``` class is modified;
a child class called ```AlphaFoldMultimerLoss``` is created that not only inherits all the loss calculations but also
contains the multi-chain permutation codes;
some loss calculations had to be modified, e.g. in the ```fape``` and ```tm``` loss calculations an extra validation was added to check whether the input tensor is a tensor_7 or a tensor_4*4, for example: https://github.com/dingquanyu/openfold/blob/02b008dc4b8c2e9e680826444c605297eeb9ffb4/openfold/utils/loss.py#L190-L193
[```openfold/config.py```](https://github.com/dingquanyu/openfold/blob/permutation/openfold/config.py) has seen a couple of modifications as well: some names were wrong, and the previous script forgot to update ```config.loss``` with ```multimer_model_config_update```.

These files are newly added:
[```tests/test_permutation.py```](https://github.com/dingquanyu/openfold/blob/permutation/tests/test_permutation.py): a unittest script
that tests the permutation functions.
[```tests/test_data/label_1.pkl```](https://github.com/dingquanyu/openfold/blob/permutation/tests/test_data/label_1.pkl)
and [```tests/test_data/label_2.pkl```](https://github.com/dingquanyu/openfold/blob/permutation/tests/test_data/label_2.pkl) are 2 fake ground-truth structures:
```label_1.pkl``` has 9 residues and ```label_2.pkl``` has 13 residues.

## Issues
Testing the codes on CPU works fine, but running them on a GPU causes ```RuntimeError: CUDA error: device-side assert triggered``` at unexpected steps.
For example, this error was raised while calculating the best rotation matrix that aligns the selected anchors during the multi-chain permutation steps. I had to use
```torch.masked_select``` and ```torch.index_select``` in https://github.com/dingquanyu/openfold/blob/a1ef4c8fa99da5cff9501051de71be440ca3cedf/openfold/utils/loss.py#L2043 and https://github.com/dingquanyu/openfold/blob/a1ef4c8fa99da5cff9501051de71be440ca3cedf/openfold/utils/loss.py#L2060 instead of simply slicing the matrix like ```matrix[index]```.
Later on, the same ```CUDA error: device-side assert triggered``` error was raised while adding dimensions to ```atom_pred_positions``` in https://github.com/dingquanyu/openfold/blob/a1ef4c8fa99da5cff9501051de71be440ca3cedf/openfold/utils/loss.py#L989
I dumped the matrices in a pickle and loaded them individually outside the programme onto a GPU, and then the indexing steps worked without the CUDA error.

### Notes
29/06/23: Fill NaN in the lddt scores with the matrix mean for now, because the test data are randomly generated and somehow produce NaN in the lddt score.
**Delete** this step before merging to the Multimer branch.
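The NaN-filling workaround described in the note above can be sketched as follows. This is a toy, torch-free stand-in: `fill_nan_with_mean` is a hypothetical helper name, and the real code operates on a tensor of lddt scores rather than a list.

```python
import math

def fill_nan_with_mean(scores):
    """Replace NaN entries with the mean of the finite entries.

    Hypothetical list-of-floats version of the workaround applied to the
    lddt score matrix; falls back to 0.0 when every entry is NaN.
    """
    finite = [s for s in scores if not math.isnan(s)]
    mean = sum(finite) / len(finite) if finite else 0.0
    return [mean if math.isnan(s) else s for s in scores]

print(fill_nan_with_mean([1.0, float("nan"), 0.5]))  # [1.0, 0.75, 0.5]
```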
```diff
@@ -24,8 +24,7 @@ import tempfile
 from openfold.utils.tensor_utils import (
     tensor_tree_map,
 )
 import logging
 logger = logging.getLogger(__name__)

 @contextlib.contextmanager
 def temp_fasta_file(sequence_str):
```
```diff
@@ -471,8 +470,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         mmcif_id = self.idx_to_mmcif_id(idx)
-        print(f"mmcif_id is :{mmcif_id}")
         chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
+        print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains")
         seqs = self.mmcif_data_cache[mmcif_id]['seqs']
         fasta_str = ""
         for c,s in zip(chains,seqs):
```
```diff
@@ -779,7 +778,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
         selected_idx = self.filter_samples(dataset_idx)
         if len(selected_idx)<self.epoch_len:
             self.epoch_len = len(selected_idx)
-        print(f"self.epoch_len is {self.epoch_len}")
+        logging.info(f"self.epoch_len is {self.epoch_len}")
         self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len)]

 class OpenFoldBatchCollator:
```
```diff
@@ -874,51 +873,25 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         return _batch_prop_gen(it)

-class OpenFoldMultimerDataLoader(OpenFoldDataLoader):
+class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
     def __init__(self, *args, config, stage="train", generator=None, **kwargs):
-        super(OpenFoldMultimerDataLoader,self).__init__(*args, config=config, stage=stage, generator=generator, **kwargs)
+        super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs)
+        self.config = config
+        self.stage = stage
+        if(generator is None):
+            generator = torch.Generator()
+        self.generator = generator
+        print('initialised a multimer dataloader')

-    def _add_batch_properties(self, batch):
-        samples = torch.multinomial(
-            self.prop_probs_tensor,
-            num_samples=1, # 1 per row
-            replacement=True,
-            generator=self.generator)
-
-        def process_samples(batch,samples):
-            aatype = batch["aatype"]
-            batch_dims = aatype.shape[:-2]
-            recycling_dim = aatype.shape[-1]
-            no_recycling = recycling_dim
-            for i, key in enumerate(self.prop_keys):
-                sample = int(samples[i][0])
-                sample_tensor = torch.tensor(
-                    sample,
-                    device=aatype.device,
-                    requires_grad=False
-                )
-                orig_shape = sample_tensor.shape
-                sample_tensor = sample_tensor.view(
-                    (1,) * len(batch_dims) + sample_tensor.shape + (1,)
-                )
-                sample_tensor = sample_tensor.expand(
-                    batch_dims + orig_shape + (recycling_dim,)
-                )
-                batch[key] = sample_tensor
-                if(key == "no_recycling_iters"):
-                    no_recycling = sample
-
-            resample_recycling = lambda t: t[..., :no_recycling + 1]
-            batch = tensor_tree_map(resample_recycling, batch)
-            return batch
-
-        all_chain_features,ground_truth = batch
-        all_chain_features = process_samples(all_chain_features,samples)
-        ground_truth = [process_samples(i,samples) for i in ground_truth]
-        return (all_chain_features,ground_truth)
+    def __iter__(self):
+        it = super().__iter__()
+
+        def _batch_prop_gen(iterator):
+            for batch in iterator:
+                yield batch
+
+        return _batch_prop_gen(it)
```
class OpenFoldDataModule(pl.LightningDataModule): class OpenFoldDataModule(pl.LightningDataModule):
...@@ -1259,15 +1232,12 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1259,15 +1232,12 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
raise ValueError("Invalid stage") raise ValueError("Invalid stage")
dl = OpenFoldMultimerDataLoader( dl = torch.utils.data.DataLoader(
dataset, dataset,
config=self.config, batch_size=1,
stage=stage,
generator=generator,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers, num_workers=self.config.data_module.data_loaders.num_workers,
) )
print(f"generated training dataloader")
return dl return dl
class DummyDataset(torch.utils.data.Dataset): class DummyDataset(torch.utils.data.Dataset):
......
```diff
@@ -535,7 +535,10 @@ class AlphaFold(nn.Module):
                 # Enable grad iff we're training and it's the final recycling layer
                 is_final_iter = cycle_no == (num_iters - 1) or early_stop
-                with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
+                enable_grad = is_grad_enabled and is_final_iter
+                if (type(enable_grad)!=bool) and (type(enable_grad)==torch.Tensor):
+                    enable_grad = enable_grad.item()
+                with torch.set_grad_enabled(enable_grad):
                     if is_final_iter:
                         # Sidestep AMP bug (PyTorch issue #65766)
                         if torch.is_autocast_enabled():
```
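The guard added in this hunk unwraps a possible 0-d tensor into a plain Python bool before it reaches `torch.set_grad_enabled`. The idea can be sketched without torch; the `Scalar` class below is a made-up stand-in for a 0-d tensor:

```python
class Scalar:
    """Made-up stand-in for a 0-d tensor: a truthy object exposing .item()."""
    def __init__(self, value):
        self.value = value

    def item(self):
        # Like tensor.item(): unwrap the single stored Python scalar.
        return self.value

def to_plain_bool(flag):
    # Mirror of the added guard: unwrap tensor-like flags, pass bools through.
    if not isinstance(flag, bool) and hasattr(flag, "item"):
        flag = flag.item()
    return bool(flag)

print(to_plain_bool(True))           # True
print(to_plain_bool(Scalar(False)))  # False
```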
```diff
@@ -980,7 +980,6 @@ def between_residue_clash_loss(
             shape (N, 14)
     """
     fp_type = atom14_pred_positions.dtype
     # Create the distance matrix.
     # (N, N, 14, 14)
     dists = torch.sqrt(
```
```diff
@@ -1234,7 +1233,7 @@ def find_structural_violations(
         batch["atom14_atom_exists"]
         * atomtype_radius[batch["residx_atom14_to_atom37"]]
     )
     torch.cuda.memory_summary()
     # Compute the between residue clash loss.
     between_residue_clashes = between_residue_clash_loss(
         atom14_pred_positions=atom14_pred_positions,
```
```diff
@@ -1710,36 +1709,30 @@ def kabsch_rotation(P, Q):
     """
     assert P.shape == torch.Size([Q.shape[0],Q.shape[1]])
-    finished_rotation = False
-    while not finished_rotation:
-        #
-        # Add a try-except block cuz sometimes SVD fails to converge and crashes the programme
-        # Will continue trying SVD until the optimal rotation is calculated
-        #
-        try:
-            # first need to load P and Q to cpu otherwise cannot extract the numpy matrices
-            rotation = procrustes.rotational(P.to('cpu').numpy(),
-                                             Q.to('cpu').numpy(),translate=True)
-            finished_rotation = True
-        except:
-            print(f"svd failed.")
-            import sys
-            sys.exit()
-    rotation = torch.tensor(rotation.t,dtype=torch.float)
+    rotation = procrustes.rotational(P.detach().cpu().numpy(),
+                                     Q.detach().cpu().numpy(),translate=False,scale=False)
+    rotation = torch.tensor(rotation.t,dtype=torch.float) # rotation.t doesn't mean transpose, t only means get the matrix out of the procrustes object
     assert rotation.shape == torch.Size([3,3])
-    return rotation
+    return rotation.to('cuda')

 def get_optimal_transform(
     src_atoms: torch.Tensor,
     tgt_atoms: torch.Tensor,
     mask: torch.Tensor = None,
 ):
+    """
+    src_atoms: predicted CA positions, shape:[num_res,3]
+    tgt_atoms: ground-truth CA positions, shape:[num_res,3]
+    mask: a vector of boolean values, shape:[num_res]
+    """
     assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
     assert src_atoms.shape[-1] == 3
-    if torch.isnan(src_atoms).any():
+    assert len(mask.shape) ==1,"mask should have the shape of [num_res]"
+    if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any():
         #
         # sometimes using fake test inputs generates NaN in the predicted atom positions
         #
+        logging.warning(f"src_atom has nan or inf")
         src_atoms = torch.nan_to_num(src_atoms,nan=0.0,posinf=1.0,neginf=1.0)

     if mask is not None:
```
```diff
@@ -1749,15 +1742,15 @@ def get_optimal_transform(
             src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
             tgt_atoms = src_atoms
         else:
-            src_atoms = src_atoms.to('cuda:0')[mask, :]
-            tgt_atoms = tgt_atoms.to('cuda:0')[mask, :]
+            src_atoms = src_atoms[mask, :]
+            tgt_atoms = tgt_atoms[mask, :]
     src_center = src_atoms.mean(-2, keepdim=True)
     tgt_center = tgt_atoms.mean(-2, keepdim=True)
     r = kabsch_rotation(src_atoms,tgt_atoms)
     del src_atoms,tgt_atoms,
     gc.collect()
-    tgt_center,src_center = tgt_center.to('cuda:0'),src_center.to('cuda:0')
+    tgt_center,src_center = tgt_center.to('cuda'),src_center.to('cuda')
     x = tgt_center.to('cpu') - src_center.to('cpu') @ r.to('cpu')
     del tgt_center,src_center,mask
```
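`kabsch_rotation` above delegates the optimal-rotation problem to `procrustes.rotational`. For intuition, here is a pure-Python sketch of the same idea in 2D, where the optimal rotation aligning two centred point sets has a closed form (a single `atan2`, no SVD needed). This is an illustrative analogue under that 2D assumption, not the repository's code:

```python
import math

def optimal_rotation_2d(P, Q):
    """Best rotation angle (radians) aligning centred 2-D point sets P -> Q.

    2-D analogue of the Kabsch/procrustes rotation: the SVD collapses to
    atan2 of the cross- and dot-product sums over corresponding points.
    """
    num = sum(px * qy - py * qx for (px, py), (qx, qy) in zip(P, Q))
    den = sum(px * qx + py * qy for (px, py), (qx, qy) in zip(P, Q))
    return math.atan2(num, den)

def rotate(P, theta):
    # Rotate every point of P by theta about the origin.
    c, s = math.cos(theta), math.sin(theta)
    return [(c * x - s * y, s * x + c * y) for x, y in P]

# Target is the source rotated by 90 degrees; recover the angle.
P = [(1.0, 0.0), (0.0, 2.0), (-1.0, -1.0)]
Q = rotate(P, math.pi / 2)
theta = optimal_rotation_2d(P, Q)
print(round(theta, 6))  # 1.570796
```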
```diff
@@ -2047,7 +2040,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         per_asym_residue_index = {}
         for cur_asym_id in unique_asym_ids:
             asym_mask = (batch["asym_id"] == cur_asym_id).bool()
-            per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
+            per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"],asym_mask)

         anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
         anchor_gt_idx = int(anchor_gt_asym) - 1
```
```diff
@@ -2060,22 +2053,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
                 entity_2_asym_list[int(cur_ent_id)] = cur_asym_id

         asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
         anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
-        anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
-        anchor_pred_pos = pred_ca_pos[asym_mask]
-        anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
-        anchor_pred_mask = pred_ca_mask[asym_mask]
-        input_mask = (anchor_true_mask.to('cuda:0') * anchor_pred_mask.to('cuda:0')).bool()
+        anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_residue_idx)
+        anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
+        # anchor_pred_pos = anchor_pred_pos.to('cuda')
+        anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_residue_idx)
+        anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
+        # anchor_pred_mask = anchor_pred_mask.to('cuda')
+        input_mask = (anchor_true_mask * anchor_pred_mask).bool()
         r, x = get_optimal_transform(
-            anchor_true_pos,
-            anchor_pred_pos,mask=input_mask
+            anchor_pred_pos,anchor_true_pos[0],
+            mask=input_mask[0]
         )
         del input_mask # just to save memory
         del anchor_pred_mask
         del anchor_true_mask
         gc.collect()
-        aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
+        aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
         align = greedy_align(
             batch,
             per_asym_residue_index,
```
```diff
@@ -2089,6 +2084,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         del aligned_true_ca_poses
         del r,x
+        del pred_ca_pos,pred_ca_mask,true_ca_poses,true_ca_masks
+        del anchor_pred_pos,anchor_true_pos
         gc.collect()
         print(f"finished multi-chain permutation and final align is {align}")
         merged_labels = merge_labels(
```
```diff
@@ -2117,7 +2114,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         permutated_labels.pop('aatype')
         features.update(permutated_labels)
         move_to_cpu = lambda t: (t.to('cpu'))
-        features = tensor_tree_map(move_to_cpu,features)
+        # features = tensor_tree_map(move_to_cpu,features)
         if (not _return_breakdown):
             cum_loss = self.loss(out,features,_return_breakdown)
             print(f"cum_loss: {cum_loss}")
```
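The multi-chain permutation logic these hunks touch boils down to: align one anchor chain, then decide which ground-truth chain each remaining predicted chain should be scored against. A toy, torch-free sketch of that assignment step follows; `best_chain_permutation` is an invented name, and it is exhaustive over permutations, whereas the repository's `greedy_align` solves the same matching greedily:

```python
import itertools

def sq_dist(a, b):
    # Sum of squared distances between two equal-length chains of 3-D points.
    return sum((ax - bx) ** 2 + (ay - by) ** 2 + (az - bz) ** 2
               for (ax, ay, az), (bx, by, bz) in zip(a, b))

def best_chain_permutation(pred_chains, true_chains):
    """Invented helper: exhaustively pick the assignment of predicted chains
    to ground-truth chains that minimises the total squared CA distance."""
    best, best_cost = None, float("inf")
    for perm in itertools.permutations(range(len(true_chains))):
        cost = sum(sq_dist(p, true_chains[j]) for p, j in zip(pred_chains, perm))
        if cost < best_cost:
            best, best_cost = perm, cost
    return best

# Two identical-sequence chains whose labels are swapped in the prediction.
true_chains = [[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
               [(10.0, 0.0, 0.0), (11.0, 0.0, 0.0)]]
pred_chains = [true_chains[1], true_chains[0]]  # swapped copies
print(best_chain_permutation(pred_chains, true_chains))  # (1, 0)
```

The exhaustive search is fine for a handful of chains; with many identical chains the factorial blow-up is exactly why a greedy anchor-based matching is used instead.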
# STOCKHOLM 1.0
#=GS MGYP000184479417/51-87 DE [subseq from] PL=10 UP=0 BIOMES=0000000010100
#=GS MGYP001032527956/9-36 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
#=GS MGYP001032527956/100-127 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
#=GS MGYP001032527956/130-155 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
#=GS MGYP001032527956/177-204 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
query ALKKHHENEISHHAKEIERLQKEIERHKQSIKKLKQSEDDD
MGYP000184479417/51-87 ALKKHHEEEIVHHKKEIERLQKEIERHKQKIKMLKHD----
MGYP001032527956/9-36 -------SKISEMTNEISRLNSEIENYKQQIESLN------
MGYP001032527956/100-127 -------NKINDLNNEISRLNSEIENYKQQIETLN------
MGYP001032527956/130-155 ---------INDLTNEISRLNSEIENYKQQIETLN------
MGYP001032527956/177-204 -------NKINDLNNEISRLNSEIDNYKQQIETLN------
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
# STOCKHOLM 1.0
#=GS MGYP000184479417/51-87 DE [subseq from] PL=10 UP=0 BIOMES=0000000010100
#=GS MGYP001032527956/9-36 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
#=GS MGYP001032527956/100-127 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
#=GS MGYP001032527956/130-155 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
#=GS MGYP001032527956/177-204 DE [subseq from] PL=11 UP=0 BIOMES=0000000000001
query ALKKHHENEISHHAKEIERLQKEIERHKQSIKKLKQSEDDD
MGYP000184479417/51-87 ALKKHHEEEIVHHKKEIERLQKEIERHKQKIKMLKHD----
MGYP001032527956/9-36 -------SKISEMTNEISRLNSEIENYKQQIESLN------
MGYP001032527956/100-127 -------NKINDLNNEISRLNSEIENYKQQIETLN------
MGYP001032527956/130-155 ---------INDLTNEISRLNSEIENYKQQIETLN------
MGYP001032527956/177-204 -------NKINDLNNEISRLNSEIDNYKQQIETLN------
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
{"3zee": {"release_date": "2013-10-16", "chain_ids": ["A"], "seqs": ["SEFKVTVCFGRTRVVVPCGDGRMKVFSLIQQAVTRYRKAVAKDPNYWIQVHRLEHGDGGILDLDDILCDVADDKDRLVAVFDEQ"], "no_chains": 1, "resolution": 6.1}, "4i6p": {"release_date": "2013-07-17", "chain_ids": ["A", "B"], "seqs": ["GPGSEFKVTVCFGRTRVVVPCGDGRMKVFSLIQQAVTRYRKAVAKDPNYWIQVHRLEHGDGGILDLDDILCDVADDKDRLVAVFDEQD", "GPGSEFKVTVCFGRTRVVVPCGDGRMKVFSLIQQAVTRYRKAVAKDPNYWIQVHRLEHGDGGILDLDDILCDVADDKDRLVAVFDEQD"], "no_chains": 2, "resolution": 2.9}, "5kc1": {"release_date": "2016-10-05", "chain_ids": ["C", "D", "A", "B", "G", "H", "K", "L", "E", "F", "I", "J"], "seqs": ["MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", 
"MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN", "MSTLAEVYTIIEDAEQECRKGDFTNAKAKYQEAIEVLGPQNENLSQNKLSSDVTQAIDLLKQDITAKIQELELLIEKQSSEENNIGMVNNNMLIGSVILNNKSPINGISNARNWDNPAYQDTLSPINDPLLMSILNRLQFNLNNDIQLKTEGGKNSKNSEMKINLRLEQFKKELVLYEQKKFKEYGMKIDEITKENKKLANEIGRLRERWDSLVESAKQRRDKQKN"], "no_chains": 12, "resolution": 2.2}, "2crb": {"release_date": "2005-11-20", "chain_ids": ["A"], "seqs": ["GSSGSSGMEGPLNLAHQQSRRADRLLAAGKYEEAISCHRKATTYLSEAMKLTESEQAHLSLELQRDSHMKQLLLIQERWKRAKREERLKAHSGPSSG"], "no_chains": 1, "resolution": 0.0}, "2q2k": {"release_date": "2008-02-05", "chain_ids": ["A", "B"], "seqs": ["MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP", "MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP"], "no_chains": 2, "resolution": 3.0}, "1hf9": {"release_date": "2001-05-31", "chain_ids": ["A", "B"], "seqs": ["ALKKHHENEISHHAKEIERLQKEIERHKQSIKKLKQSEDDD", 
"ALKKHHENEISHHAKEIERLQKEIERHKQSIKKLKQSEDDD"], "no_chains": 2, "resolution": 0.0}, "3u8v": {"release_date": "2011-11-09", "chain_ids": ["A", "B"], "seqs": ["SGHTAHVDEAVKHAEEAVAHGKEGHTDQLLEHAKESLTHAKAASEAGGNTHVGHGIKHLEDAIKHGEEGHVGVATKHAQEAIEHLRASEHKSH", "SGHTAHVDEAVKHAEEAVAHGKEGHTDQLLEHAKESLTHAKAASEAGGNTHVGHGIKHLEDAIKHGEEGHVGVATKHAQEAIEHLRASEHKSH"], "no_chains": 2, "resolution": 1.9}, "1psm": {"release_date": "1995-02-07", "chain_ids": ["A"], "seqs": ["EAYKKAKQASQDAEQAAKDAENASKEAEEAAKEAVNLK"], "no_chains": 1, "resolution": 0.0}, "4zey": {"release_date": "2015-05-06", "chain_ids": ["A"], "seqs": ["GMEGPLNLAHQQSRRADRLLAAGKYEEAISCHKKAAAYLSEAMKLTQSEQAHLSLELQRDSHMKQLLLIQERWKRAQREERLKA"], "no_chains": 1, "resolution": 1.5}}
\ No newline at end of file
```diff
@@ -16,18 +16,18 @@ import torch
 from openfold.config import model_config
 from openfold.data.data_modules import (
-    OpenFoldDataModule,
+    OpenFoldDataModule,OpenFoldMultimerDataModule,
     DummyDataLoader,
 )
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
-from openfold.utils.argparse import remove_arguments
+from openfold.utils.argparse_utils import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
-from openfold.utils.loss import AlphaFoldLoss, lddt_ca
+from openfold.utils.loss import AlphaFoldLoss, AlphaFoldMultimerLoss,lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
```
```diff
@@ -257,6 +257,69 @@ class OpenFoldWrapper(pl.LightningModule):
         )

+class OpenFoldMultimerWrapper(OpenFoldWrapper):
+    def __init__(self, config):
+        super(OpenFoldMultimerWrapper, self).__init__(config)
+        self.config = config
+        self.config.loss.masked_msa.num_classes = 22 # somehow need overwrite this part in multimer loss config
+        self.config.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
+        self.config.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
+        self.model = AlphaFold(config)
+        self.loss = AlphaFoldMultimerLoss(config.loss)
+        self.ema = ExponentialMovingAverage(
+            model=self.model, decay=config.ema.decay
+        )
+        self.cached_weights = None
+        self.last_lr_step = -1
+
+    def forward(self, batch):
+        return self.model(batch)
+
+    def training_step(self, batch, batch_idx):
+        all_chain_features,ground_truth = batch
+        if(self.ema.device != all_chain_features["aatype"].device):
+            self.ema.to(all_chain_features["aatype"].device)
+
+        # Run the model
+        outputs = self(all_chain_features)
+
+        # Compute loss
+        loss = self.loss(
+            outputs, (all_chain_features,ground_truth), _return_breakdown=False
+        )
+
+        # Log it
+        self._log(loss, all_chain_features, outputs)
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        all_chain_features,ground_truth = batch
+        # At the start of validation, load the EMA weights
+        if(self.cached_weights is None):
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
+            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            self.model.load_state_dict(self.ema.state_dict()["params"])
+
+        # Run the model
+        outputs = self(all_chain_features)
+
+        # Compute loss and other metrics
+        all_chain_features["use_clamped_fape"] = 0.
+        _, loss_breakdown = self.loss(
+            outputs, all_chain_features, _return_breakdown=True
+        )
+
+        self._log(loss_breakdown, all_chain_features, outputs, train=False)
+
+    def validation_epoch_end(self, _):
+        # Restore the model weights to normal
+        self.model.load_state_dict(self.cached_weights)
+        self.cached_weights = None

 def main(args):
     if(args.seed is not None):
         seed_everything(args.seed)
```
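The wrapper above keeps an `ExponentialMovingAverage` of the model weights and swaps them in for validation. The update rule is just a decayed blend per parameter; a torch-free sketch (the dict-of-floats form and the parameter name `w` are made up for illustration):

```python
def ema_update(shadow, params, decay=0.999):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * params.

    Toy dict-of-floats version of the ExponentialMovingAverage used by the
    wrapper; the real class applies this tensor-wise over model.state_dict().
    """
    return {k: decay * shadow[k] + (1.0 - decay) * params[k] for k in shadow}

# With decay=0.5 the shadow halves its distance to the live weights each step.
shadow = {"w": 0.0}
params = {"w": 1.0}
for _ in range(3):
    shadow = ema_update(shadow, params, decay=0.5)
print(shadow["w"])  # 0.875
```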
```diff
@@ -266,8 +329,10 @@ def main(args):
         train=True,
         low_prec=(str(args.precision) == "16")
     )
-    model_module = OpenFoldWrapper(config)
+    if "multimer" in args.config_preset:
+        model_module = OpenFoldMultimerWrapper(config)
+    else:
+        model_module = OpenFoldWrapper(config)
     if(args.resume_from_ckpt):
         if(os.path.isdir(args.resume_from_ckpt)):
             last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
```
```diff
@@ -293,11 +358,18 @@ def main(args):
         script_preset_(model_module)

     #data_module = DummyDataLoader("new_batch.pickle")
-    data_module = OpenFoldDataModule(
-        config=config.data,
-        batch_seed=args.seed,
-        **vars(args)
-    )
+    if "multimer" in args.config_preset:
+        data_module = OpenFoldMultimerDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )
+    else:
+        data_module = OpenFoldDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )

     data_module.prepare_data()
     data_module.setup()
```
```diff
@@ -417,6 +489,10 @@ if __name__ == "__main__":
         help='''Cutoff for all templates. In training mode, templates are also
             filtered by the release date of the target'''
     )
+    parser.add_argument(
+        "--train_mmcif_data_cache_path", type=str, default=None,
+        help="path to the json file which records all the information of the mmcif structures used during training"
+    )
     parser.add_argument(
         "--distillation_data_dir", type=str, default=None,
         help="Directory containing training PDB files"
     )
```