Unverified Commit 31051cf2 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #336 from dingquanyu/permutation

Added multi-chain permutation steps, multimer datamodule, and training code for multimer
parents 4ca64437 e963726b
...@@ -163,7 +163,7 @@ def model_config( ...@@ -163,7 +163,7 @@ def model_config(
for k,v in multimer_model_config_update['model'].items(): for k,v in multimer_model_config_update['model'].items():
c.model[k] = v c.model[k] = v
for k, v in multimer_model_config_update['loss'].items(): for k,v in multimer_model_config_update['loss'].items():
c.loss[k] = v c.loss[k] = v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
...@@ -683,8 +683,7 @@ config = mlc.ConfigDict( ...@@ -683,8 +683,7 @@ config = mlc.ConfigDict(
) )
multimer_model_config_update = { multimer_model_config_update = {
"model": { 'model':{"input_embedder": {
"input_embedder": {
"tf_dim": 21, "tf_dim": 21,
"msa_dim": 49, "msa_dim": 49,
#"num_msa": 508, #"num_msa": 508,
...@@ -695,6 +694,20 @@ multimer_model_config_update = { ...@@ -695,6 +694,20 @@ multimer_model_config_update = {
"max_relative_idx": 32, "max_relative_idx": 32,
"use_chain_relative": True, "use_chain_relative": True,
}, },
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": { "template": {
"distogram": { "distogram": {
"min_bin": 3.25, "min_bin": 3.25,
...@@ -828,6 +841,8 @@ multimer_model_config_update = { ...@@ -828,6 +841,8 @@ multimer_model_config_update = {
}, },
"recycle_early_stop_tolerance": 0.5 "recycle_early_stop_tolerance": 0.5
}, },
"recycle_early_stop_tolerance": 0.5
},
"loss": { "loss": {
"distogram": { "distogram": {
"min_bin": 2.3125, "min_bin": 2.3125,
...@@ -903,5 +918,5 @@ multimer_model_config_update = { ...@@ -903,5 +918,5 @@ multimer_model_config_update = {
"enabled": True, "enabled": True,
}, },
"eps": eps, "eps": eps,
} },
} }
This diff is collapsed.
...@@ -784,6 +784,45 @@ class DataPipeline: ...@@ -784,6 +784,45 @@ class DataPipeline:
return all_hits return all_hits
def _parse_template_hits(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None,
input_sequence=None,
) -> Mapping[str, Any]:
all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
elif (ext =='.sto') and (f.startswith("hmm")):
with open(path,"r") as fp:
hits = parsers.parse_hmmsearch_sto(fp.read(),input_sequence)
all_hits[f] = hits
fp.close()
return all_hits
def _get_msas(self, def _get_msas(self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
...@@ -890,8 +929,7 @@ class DataPipeline: ...@@ -890,8 +929,7 @@ class DataPipeline:
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence, alignment_index,input_sequence)
alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
...@@ -939,8 +977,7 @@ class DataPipeline: ...@@ -939,8 +977,7 @@ class DataPipeline:
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence, alignment_index,input_sequence
alignment_index
) )
template_features = make_template_features( template_features = make_template_features(
...@@ -972,8 +1009,7 @@ class DataPipeline: ...@@ -972,8 +1009,7 @@ class DataPipeline:
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence, alignment_index,input_sequence
alignment_index
) )
template_features = make_template_features( template_features = make_template_features(
...@@ -1062,7 +1098,7 @@ class DataPipeline: ...@@ -1062,7 +1098,7 @@ class DataPipeline:
alignment_dir = os.path.join( alignment_dir = os.path.join(
super_alignment_dir, desc super_alignment_dir, desc
) )
hits = self._parse_template_hit_files(alignment_dir, seq, alignment_index=None) hits = self._parse_template_hits(alignment_dir, alignment_index=None,input_sequence=input_sequence)
template_features = make_template_features( template_features = make_template_features(
seq, seq,
hits, hits,
......
...@@ -134,8 +134,8 @@ class FeaturePipeline: ...@@ -134,8 +134,8 @@ class FeaturePipeline:
mode: str = "train", mode: str = "train",
is_multimer: bool = False, is_multimer: bool = False,
) -> FeatureDict: ) -> FeatureDict:
if(is_multimer and mode != "predict"): # if(is_multimer and mode != "predict"):
raise ValueError("Multimer mode is not currently trainable") # raise ValueError("Multimer mode is not currently trainable")
return np_example_to_features( return np_example_to_features(
np_example=raw_features, np_example=raw_features,
......
...@@ -34,7 +34,13 @@ from openfold.utils.tensor_utils import ( ...@@ -34,7 +34,13 @@ from openfold.utils.tensor_utils import (
permute_final_dims, permute_final_dims,
batched_gather, batched_gather,
) )
import random
from openfold.np import residue_constants as rc
import logging
import procrustes
from openfold.utils.tensor_utils import tensor_tree_map
import gc
logger = logging.getLogger(__name__)
def softmax_cross_entropy(logits, labels): def softmax_cross_entropy(logits, labels):
loss = -1 * torch.sum( loss = -1 * torch.sum(
...@@ -179,7 +185,13 @@ def backbone_loss( ...@@ -179,7 +185,13 @@ def backbone_loss(
eps: float = 1e-4, eps: float = 1e-4,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
### need to check if the traj belongs to 4*4 matrix or a tensor_7
if traj.shape[-1]==7:
pred_aff = Rigid.from_tensor_7(traj) pred_aff = Rigid.from_tensor_7(traj)
elif traj.shape[-1]==4:
pred_aff = Rigid.from_tensor_4x4(traj)
pred_aff = Rigid( pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None), Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(), pred_aff.get_trans(),
...@@ -298,10 +310,10 @@ def fape_loss( ...@@ -298,10 +310,10 @@ def fape_loss(
interface_bb_loss = backbone_loss( interface_bb_loss = backbone_loss(
traj=traj, traj=traj,
pair_mask=1. - intra_chain_mask, pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface_backbone}, **{**batch, **config.interface},
) )
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface_backbone.weight) + interface_bb_loss * config.interface.weight)
else: else:
bb_loss = backbone_loss( bb_loss = backbone_loss(
traj=traj, traj=traj,
...@@ -529,9 +541,9 @@ def lddt_loss( ...@@ -529,9 +541,9 @@ def lddt_loss(
cutoff=cutoff, cutoff=cutoff,
eps=eps eps=eps
) )
score = torch.nan_to_num(score,nan=torch.nanmean(score))
score[score<0] = 0
score = score.detach() score = score.detach()
bin_index = torch.floor(score * no_bins).long() bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1)) bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot( lddt_ca_one_hot = torch.nn.functional.one_hot(
...@@ -725,7 +737,11 @@ def tm_loss( ...@@ -725,7 +737,11 @@ def tm_loss(
eps=1e-8, eps=1e-8,
**kwargs, **kwargs,
): ):
# first check whether this is a tensor_7 or tensor_4*4
if final_affine_tensor.shape[-1]==7:
pred_affine = Rigid.from_tensor_7(final_affine_tensor) pred_affine = Rigid.from_tensor_7(final_affine_tensor)
elif final_affine_tensor.shape[-1]==4:
pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor) backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine): def _points(affine):
...@@ -838,6 +854,7 @@ def between_residue_bond_loss( ...@@ -838,6 +854,7 @@ def between_residue_bond_loss(
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[ ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1 1
] ]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2) c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu( c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev c_n_bond_length_error - tolerance_factor_soft * gt_stddev
...@@ -963,7 +980,6 @@ def between_residue_clash_loss( ...@@ -963,7 +980,6 @@ def between_residue_clash_loss(
shape (N, 14) shape (N, 14)
""" """
fp_type = atom14_pred_positions.dtype fp_type = atom14_pred_positions.dtype
# Create the distance matrix. # Create the distance matrix.
# (N, N, 14, 14) # (N, N, 14, 14)
dists = torch.sqrt( dists = torch.sqrt(
...@@ -1217,7 +1233,7 @@ def find_structural_violations( ...@@ -1217,7 +1233,7 @@ def find_structural_violations(
batch["atom14_atom_exists"] batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]] * atomtype_radius[batch["residx_atom14_to_atom37"]]
) )
torch.cuda.memory_summary()
# Compute the between residue clash loss. # Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss( between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions, atom14_pred_positions=atom14_pred_positions,
...@@ -1622,7 +1638,7 @@ def chain_center_of_mass_loss( ...@@ -1622,7 +1638,7 @@ def chain_center_of_mass_loss(
asym_id: torch.Tensor, asym_id: torch.Tensor,
clamp_distance: float = -4.0, clamp_distance: float = -4.0,
weight: float = 0.05, weight: float = 0.05,
eps: float = 1e-10 eps: float = 1e-10, **kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper. Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
...@@ -1649,9 +1665,9 @@ def chain_center_of_mass_loss( ...@@ -1649,9 +1665,9 @@ def chain_center_of_mass_loss(
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True)
chains = asym_id.unique() one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64)-1, # have to reduce asym_id by one because class values must be smaller than num_classes
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int
one_hot = one_hot * all_atom_mask one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1) chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float() chain_exists = torch.any(chain_pos_mask, dim=-1).float()
...@@ -1672,6 +1688,216 @@ def chain_center_of_mass_loss( ...@@ -1672,6 +1688,216 @@ def chain_center_of_mass_loss(
loss = masked_mean(loss_mask, losses, dim=(-1, -2)) loss = masked_mean(loss_mask, losses, dim=(-1, -2))
return loss return loss
# #
# below are the functions required for permutations
# #
def kabsch_rotation(P, Q):
    """
    Compute the rotation matrix that minimises the RMSD between P and Q.

    The optimal rotation is obtained from the rotational() function of the
    procrustes package. Details:
    https://procrustes.qcdevs.org/api/rotational.html#rotational

    Args:
        P: [N, 3] tensor; each row is an atom's x,y,z coordinates
        Q: [N, 3] tensor of the same shape as P

    Returns:
        A [3, 3] rotation matrix on the same device as P
    """
    assert P.shape == Q.shape
    result = procrustes.rotational(
        P.detach().cpu().numpy(),
        Q.detach().cpu().numpy(),
        translate=False,
        scale=False,
    )
    # result.t is the rotation matrix itself, not a transpose — "t" is just
    # the attribute name the procrustes result object uses.
    rotation = torch.tensor(result.t, dtype=torch.float)
    assert rotation.shape == torch.Size([3, 3])
    # Return on P's device instead of hard-coding 'cuda', so the loss also
    # works on CPU-only hosts (the original crashed without a GPU).
    return rotation.to(P.device)
def get_optimal_transform(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: torch.Tensor = None,
):
    """
    Find the rigid transform (rotation r, translation x) that best maps
    src_atoms onto tgt_atoms.

    Args:
        src_atoms: predicted CA positions, shape [num_res, 3]
        tgt_atoms: ground-truth CA positions, shape [num_res, 3]
        mask: optional boolean vector of shape [num_res]; only residues
            with mask == True contribute to the fit

    Returns:
        r: [3, 3] rotation matrix (as produced by kabsch_rotation)
        x: [1, 3] translation vector on src_atoms' original device
    """
    assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
    assert src_atoms.shape[-1] == 3
    device = src_atoms.device
    if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any():
        # Fake/test inputs occasionally produce NaN/inf predicted atom
        # positions; sanitise them so the fit below stays finite.
        logger.warning("src_atom has nan or inf")
        src_atoms = torch.nan_to_num(src_atoms, nan=0.0, posinf=1.0, neginf=1.0)
    if mask is not None:
        # Validate the mask only when one is supplied. The original asserted
        # on mask.shape before this None check, crashing for the documented
        # default of mask=None.
        assert mask.dtype == torch.bool
        assert len(mask.shape) == 1, "mask should have the shape of [num_res]"
        assert mask.shape[-1] == src_atoms.shape[-2]
        if mask.sum() == 0:
            # Nothing to fit against; fall back to a degenerate single point.
            src_atoms = torch.zeros((1, 3), device=device).float()
            tgt_atoms = src_atoms
        else:
            src_atoms = src_atoms[mask, :]
            tgt_atoms = tgt_atoms[mask, :]
    src_center = src_atoms.mean(-2, keepdim=True)
    tgt_center = tgt_atoms.mean(-2, keepdim=True)
    r = kabsch_rotation(src_atoms, tgt_atoms)
    # Translation that maps the rotated source centroid onto the target
    # centroid. Computed on matching devices instead of the original's
    # cpu/cuda round-trips, which failed on CPU-only hosts.
    x = tgt_center - src_center @ r.to(src_center.device)
    return r, x.to(device)
def compute_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Root-mean-square deviation between two sets of atom positions.

    Args:
        true_atom_pos: [..., 3] ground-truth coordinates
        pred_atom_pos: [..., 3] predicted coordinates, same shape
        atom_mask: optional boolean mask selecting which atoms to score
        eps: small constant keeping sqrt differentiable at zero deviation

    Returns:
        Scalar RMSD tensor; a NaN mean (e.g. an all-False mask) is mapped
        to 1e8 so callers comparing RMSDs never prefer such a candidate.
    """
    # Work on whatever device the inputs already live on; the original
    # forced everything onto 'cuda:0', which broke CPU-only runs.
    sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
    if atom_mask is not None:
        sq_diff = sq_diff[atom_mask]
    msd = torch.mean(sq_diff)
    msd = torch.nan_to_num(msd, nan=1e8)
    return torch.sqrt(msd + eps)
def get_least_asym_entity_or_longest_length(batch):
    """
    Pick the ground-truth anchor entity and predicted anchor chain.

    First check how many subunit(s) each sequence has; if there is no tie,
    e.g. AABBB, select one of the A chains as anchor. If there is a tie,
    e.g. AABB, prefer the longest sequence and choose one of its subunits.

    Args:
        batch: feature dict; "entity_id" and "asym_id" tensors are read.

    Returns:
        (anchor entity id as int, anchor asym_id tensor in the prediction)
    """
    unique_entity_ids = torch.unique(batch["entity_id"])
    entity_asym_count = {}
    entity_length = {}
    for entity_id in unique_entity_ids:
        asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
        entity_asym_count[int(entity_id)] = len(asym_ids)
        # Number of residues belonging to this entity.
        entity_mask = (batch["entity_id"] == entity_id)
        entity_length[int(entity_id)] = entity_mask.sum().item()
    min_asym_count = min(entity_asym_count.values())
    least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
    # If multiple entities have the least asym_id count, keep those with the
    # longest sequence (the code uses max(); the old comment said "shortest").
    if len(least_asym_entities) > 1:
        max_length = max([entity_length[entity] for entity in least_asym_entities])
        least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
    # If still tied, pick one at random — kept inside a list so the assert
    # and the [0] access below keep working. The original assigned the bare
    # element, so len() on it raised TypeError whenever this branch ran.
    if len(least_asym_entities) > 1:
        least_asym_entities = [random.choice(least_asym_entities)]
    assert len(least_asym_entities) == 1
    best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
    # If more than one predicted chain shares the anchor entity's sequence,
    # randomly pick one of them.
    if len(best_pred_asym) > 1:
        best_pred_asym = random.choice(best_pred_asym)
    return least_asym_entities[0], best_pred_asym
def greedy_align(
    batch,
    per_asym_residue_index,
    unique_asym_ids,
    entity_2_asym_list,
    pred_ca_pos,
    pred_ca_mask,
    true_ca_poses,
    true_ca_masks,
):
    """
    Greedily match each predicted chain to an unused ground-truth chain of
    the same entity by minimum CA RMSD.

    Args:
        batch: feature dict; "asym_id" and "entity_id" are read.
        per_asym_residue_index: dict asym_id -> residue index tensor.
        unique_asym_ids: unique asym ids present in the batch.
        entity_2_asym_list: dict entity_id -> tensor of its asym ids.
        pred_ca_pos / pred_ca_mask: predicted CA positions and mask.
        true_ca_poses / true_ca_masks: per-chain ground-truth CA data
            (already transformed into the prediction's frame by the caller).

    Returns:
        List of (pred_chain_idx, gt_chain_idx) 0-based index pairs.
    """
    used = [False for _ in range(len(true_ca_poses))]
    align = []
    for cur_asym_id in unique_asym_ids:
        # asym_id 0 is padding; skip it.
        if cur_asym_id == 0:
            continue
        i = int(cur_asym_id - 1)
        asym_mask = batch["asym_id"] == cur_asym_id
        entity_id = batch["entity_id"][asym_mask][0]
        # Entity 1 keeps its identity mapping (no alternative copies are
        # considered for it in this scheme).
        if entity_id == 1:
            align.append((i, i))
            assert used[i] == False
            used[i] = True
            continue
        best_rmsd = torch.inf
        best_idx = None
        cur_asym_list = entity_2_asym_list[int(entity_id)]
        cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
        cur_pred_pos = pred_ca_pos[asym_mask]
        cur_pred_mask = pred_ca_mask[asym_mask]
        for next_asym_id in cur_asym_list:
            if next_asym_id == 0:
                continue
            j = int(next_asym_id - 1)
            if not used[j]:  # possible candidate
                # Score EVERY unused candidate. The original wrapped this in
                # `while best_idx is None:`, which stopped scoring further
                # candidates once one was accepted and could spin forever if
                # the first RMSD failed the best_rmsd comparison.
                cropped_pos = true_ca_poses[j]
                mask = true_ca_masks[j][cur_residue_index]
                # Assumes pred mask and gt mask live on the same device
                # (the original forced both onto 'cuda:0').
                rmsd = compute_rmsd(
                    cropped_pos, cur_pred_pos, (cur_pred_mask * mask).bool()
                )
                if (rmsd is not None) and (rmsd < best_rmsd):
                    best_rmsd = rmsd
                    best_idx = j
        assert best_idx is not None
        used[best_idx] = True
        align.append((i, best_idx))
    return align
def merge_labels(per_asym_residue_index, labels, align):
    """
    Reassemble ground-truth features according to a chain permutation.

    per_asym_residue_index: dict mapping asym_id (1-based) to the residue
        indices that chain occupies in the assembled complex.
    labels: list of per-chain ground-truth feature dicts.
    align: list of (pred_chain_idx, gt_chain_idx) pairs assigning each
        predicted chain its ground-truth counterpart.
    """
    merged = {}
    for feat_name, ref_value in labels[0].items():
        pieces = {}
        for pred_idx, gt_idx in align:
            gt_feat = labels[gt_idx][feat_name]
            num_res = labels[gt_idx]["aatype"].shape[-1]
            # asym ids are 1-based, hence the +1.
            residue_idx = per_asym_residue_index[pred_idx + 1]
            # Scalars and template features are not per-residue; skip them.
            if len(ref_value.shape) == 0 or "template" in feat_name:
                continue
            merge_dim = gt_feat.shape.index(num_res) if num_res in gt_feat.shape else 0
            pieces[pred_idx] = gt_feat.index_select(merge_dim, residue_idx)
        ordered = [piece for _, piece in sorted(pieces.items())]
        if len(ordered) > 0:
            merged[feat_name] = torch.concat(ordered, dim=merge_dim)
    return merged
class AlphaFoldLoss(nn.Module): class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement""" """Aggregation of the various losses described in the supplement"""
...@@ -1679,7 +1905,11 @@ class AlphaFoldLoss(nn.Module): ...@@ -1679,7 +1905,11 @@ class AlphaFoldLoss(nn.Module):
super(AlphaFoldLoss, self).__init__() super(AlphaFoldLoss, self).__init__()
self.config = config self.config = config
def forward(self, out, batch, _return_breakdown=False): def loss(self, out, batch, _return_breakdown=False):
"""
Rename previous forward() as loss()
so that can be reused in the subclass
"""
if "violation" not in out.keys(): if "violation" not in out.keys():
out["violation"] = find_structural_violations( out["violation"] = find_structural_violations(
batch, batch,
...@@ -1755,7 +1985,6 @@ class AlphaFoldLoss(nn.Module): ...@@ -1755,7 +1985,6 @@ class AlphaFoldLoss(nn.Module):
loss = loss.new_tensor(0., requires_grad=True) loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss cum_loss = cum_loss + weight * loss
losses[loss_name] = loss.detach().clone() losses[loss_name] = loss.detach().clone()
losses["unscaled_loss"] = cum_loss.detach().clone() losses["unscaled_loss"] = cum_loss.detach().clone()
# Scale the loss by the square root of the minimum of the crop size and # Scale the loss by the square root of the minimum of the crop size and
...@@ -1770,3 +1999,127 @@ class AlphaFoldLoss(nn.Module): ...@@ -1770,3 +1999,127 @@ class AlphaFoldLoss(nn.Module):
return cum_loss return cum_loss
return cum_loss, losses return cum_loss, losses
def forward(self, out, batch, _return_breakdown=False):
if(not _return_breakdown):
cum_loss = self.loss(out,batch,_return_breakdown)
return cum_loss
else:
cum_loss,losses = self.loss(out,batch,_return_breakdown)
return cum_loss, losses
class AlphaFoldMultimerLoss(AlphaFoldLoss):
    """
    AlphaFoldLoss subclass that permutes the ground-truth chains to their
    best-matching predicted chains before computing the loss.
    """
    def __init__(self, config):
        super(AlphaFoldMultimerLoss, self).__init__(config)
        self.config = config

    def multi_chain_perm_align(self,out, batch, labels, shuffle_times=2):
        """
        Permute the ground-truth chains before calculating the loss.

        Implements the multi-chain permutation alignment described in
        Section 7.3 of the Supplementary of the AlphaFold-Multimer paper:
        https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2

        Args:
            out: model output dict; reads "final_atom_positions" and
                "final_atom_mask".
            batch: input feature dict; reads "asym_id", "entity_id" and
                "residue_index".
            labels: list of per-chain ground-truth feature dicts.
            shuffle_times: NOTE(review): currently unused in this method.

        Returns:
            Ground-truth labels merged into whole-complex tensors, reordered
            according to the best chain permutation found.
        """
        assert isinstance(labels, list)
        ca_idx = rc.atom_order["CA"]
        # Predicted CA coordinates and mask for the whole complex.
        pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float()  # [bsz, nres, 3]
        pred_ca_mask = out["final_atom_mask"][..., ca_idx].float()  # [bsz, nres]
        # Ground-truth CA coordinates and masks, one entry per chain.
        true_ca_poses = [
            l["all_atom_positions"][..., ca_idx, :].float() for l in labels
        ]  # list([nres, 3])
        true_ca_masks = [
            l["all_atom_mask"][..., ca_idx].float() for l in labels
        ]  # list([nres,])
        # Map each asym_id to the residue indices belonging to that chain.
        unique_asym_ids = torch.unique(batch["asym_id"])
        per_asym_residue_index = {}
        for cur_asym_id in unique_asym_ids:
            asym_mask = (batch["asym_id"] == cur_asym_id).bool()
            per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"],asym_mask)
        # Choose the anchor: least-ambiguous entity in the ground truth and
        # the corresponding chain in the prediction.
        # NOTE(review): anchor_gt_asym is the *entity* id returned by
        # get_least_asym_entity_or_longest_length, yet it is used as a
        # 1-based index into true_ca_poses below — confirm entity ids line
        # up with the label list order.
        anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
        anchor_gt_idx = int(anchor_gt_asym) - 1
        # Map each entity id to the asym ids (chain copies) it owns.
        unique_entity_ids = torch.unique(batch["entity_id"])
        entity_2_asym_list = {}
        for cur_ent_id in unique_entity_ids:
            ent_mask = batch["entity_id"] == cur_ent_id
            cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
            entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
        # Gather the anchor chain's predicted and true CA positions/masks.
        asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
        anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
        anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_residue_idx)
        anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
        # anchor_pred_pos = anchor_pred_pos.to('cuda')
        anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_residue_idx)
        anchor_pred_mask =pred_ca_mask[0][asym_mask[0]]
        # anchor_pred_mask = anchor_pred_mask.to('cuda')
        input_mask = (anchor_true_mask * anchor_pred_mask).bool()
        # Rigid transform that best maps the predicted anchor onto the truth.
        r, x = get_optimal_transform(
            anchor_pred_pos,anchor_true_pos[0],
            mask=input_mask[0]
        )
        del input_mask # just to save memory
        del anchor_pred_mask
        del anchor_true_mask
        gc.collect()
        # Move every ground-truth chain into the prediction's frame, then
        # greedily match predicted chains to unused ground-truth chains.
        aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses]  # apply transforms
        align = greedy_align(
            batch,
            per_asym_residue_index,
            unique_asym_ids ,
            entity_2_asym_list,
            pred_ca_pos,
            pred_ca_mask,
            aligned_true_ca_poses,
            true_ca_masks,
        )
        del aligned_true_ca_poses
        del r,x
        del pred_ca_pos,pred_ca_mask,true_ca_poses,true_ca_masks
        del anchor_pred_pos,anchor_true_pos
        gc.collect()
        print(f"finished multi-chain permutation and final align is {align}")
        merged_labels = merge_labels(
            per_asym_residue_index,
            labels,
            align,
        )
        return merged_labels

    def forward(self,out,batch,_return_breakdown=False):
        """
        Overwrite AlphaFoldLoss.forward() so that the ground-truth chains
        are permuted to their best-matching predicted chains before the
        loss is calculated.

        Args:
            out: the output of model.forward()
            batch: a pair of (input features, list of per-chain ground-truth
                label dicts)
        """
        features,labels = batch
        # First remove the recycling dimension of the input features.
        features = tensor_tree_map(lambda t: t[..., -1], features)
        features['resolution'] = labels[0]['resolution']
        # Then permute the ground-truth chains before calculating the loss.
        permutated_labels = self.multi_chain_perm_align(out,features,labels)
        # NOTE(review): the permuted 'aatype' is dropped before merging into
        # features — presumably to keep the model's own aatype; confirm.
        permutated_labels.pop('aatype')
        features.update(permutated_labels)
        move_to_cpu = lambda t: (t.to('cpu'))
        # features = tensor_tree_map(move_to_cpu,features)
        if (not _return_breakdown):
            cum_loss = self.loss(out,features,_return_breakdown)
            print(f"cum_loss: {cum_loss}")
            return cum_loss
        else:
            cum_loss,losses = self.loss(out,features,_return_breakdown)
            print(f"cum_loss: {cum_loss} losses: {losses}")
            return cum_loss, losses
\ No newline at end of file
>query
MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
>tr|A0A2W3M096|A0A2W3M096_STAAU Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 GN=C7Q70_14145 PE=4 SV=1
-------------------MDKKETQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGNTP
>tr|A0A0Q9XW80|A0A0Q9XW80_9STAP Uncharacterized protein OS=Staphylococcus sp. NAM3COL9 GN=ACA31_00310 PE=4 SV=1
-------------------MSKQETNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
>tr|A0A1E5U0W4|A0A1E5U0W4_STAXY Uncharacterized protein OS=Staphylococcus xylosus GN=AST15_04830 PE=4 SV=1
-------------------MSKQETNHLLKIKKKDYPQIFDFLENVPKGTKTAHIREALIRYINDLGDTpP
This diff is collapsed.
# STOCKHOLM 1.0
#=GS MGYP000048211747/1-51 DE [subseq from] PL=00 UP=0 BIOMES=0000000011000
#=GS MGYP000256545448/1-51 DE [subseq from] PL=00 UP=0 BIOMES=0000000011000
#=GS MGYP000517307434/104-157 DE [subseq from] PL=11 UP=0 BIOMES=0000000011000
#=GS MGYP000971940026/195-224 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
#=GS MGYP000859660985/46-74 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
#=GS MGYP000859660985/83-111 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
query MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
MGYP000048211747/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
MGYP000256545448/1-51 -------------------MKKKETQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
MGYP000517307434/104-157 ----------------GDLLRQKETQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
MGYP000971940026/195-224 ------------------------------VKKSDLGQVTSFLKEVPEGKKQDVLDEVLK----------
MGYP000859660985/46-74 ------------------------------IKKSDLGQVASFLKEVPEGQKQEVLDQVL-----------
MGYP000859660985/83-111 ------------------------------IKKSDLGQVASFLKEVPEGQKQEVLDQVL-----------
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
# STOCKHOLM 1.0
#=GS tr|A0A0K0ME10|A0A0K0ME10_9STAP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus schleiferi OX=1295 GN=NP71_p00120 PE=4 SV=1
#=GS tr|A0A0C5BVQ8|A0A0C5BVQ8_STAAU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 GN=parR PE=4 SV=1
#=GS tr|A0A0D4ZYK6|A0A0D4ZYK6_STAEP/1-51 DE [subseq from] DNA-binding protein ParR OS=Staphylococcus epidermidis OX=1282 GN=parR PE=4 SV=1
#=GS tr|A0A0H2XKQ4|A0A0H2XKQ4_STAA3/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus (strain USA300) OX=367830 GN=SAUSA300_pUSA030035 PE=4 SV=1
#=GS tr|A0A0N9NJL4|A0A0N9NJL4_STAPS/1-51 DE [subseq from] Putative plasmid segregation protein ParR OS=Staphylococcus pseudintermedius OX=283734 GN=parR PE=4 SV=1
#=GS tr|A0A0U2CJ65|A0A0U2CJ65_STAEP/1-51 DE [subseq from] Plasmid segregation protein OS=Staphylococcus epidermidis OX=1282 GN=parR PE=4 SV=1
#=GS tr|A0A133QXU6|A0A133QXU6_STASI/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus simulans OX=1286 GN=HMPREF3215_00002 PE=4 SV=1
#=GS tr|A0A141BHY3|A0A141BHY3_STAXY/1-51 DE [subseq from] DNA-binding protein OS=Staphylococcus xylosus OX=1288 GN=p11 PE=4 SV=1
#=GS tr|A0A141HMK9|A0A141HMK9_STAA8/1-51 DE [subseq from] ParR OS=Staphylococcus aureus subsp. aureus RN4220 OX=561307 GN=pGO400_p33 PE=4 SV=1
#=GS tr|A0A1B1UXS0|A0A1B1UXS0_STALU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus lugdunensis OX=28035 GN=parR PE=4 SV=1
#=GS tr|A0A1S7BGJ1|A0A1S7BGJ1_STAAU/1-51 DE [subseq from] DNA-binding protein ParR OS=Staphylococcus aureus OX=1280 GN=parR PE=4 SV=1
#=GS tr|A0A418HED5|A0A418HED5_STAGA/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus gallinarum OX=1293 GN=BUY97_07835 PE=4 SV=1
#=GS tr|A0A507SJ94|A0A507SJ94_9STAP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus sp. SKL71187 OX=2497688 GN=EKV43_01520 PE=4 SV=1
#=GS tr|A0A6N0I4W4|A0A6N0I4W4_STAHO/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus hominis OX=1290 GN=FOB69_12695 PE=4 SV=1
#=GS tr|A0A7G3T6L6|A0A7G3T6L6_9STAP/1-51 DE [subseq from] Plasmid segregation protein OS=Staphylococcus equorum OX=246432 PE=4 SV=1
#=GS tr|A0A848F022|A0A848F022_STACP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus capitis OX=29388 GN=HHM13_04665 PE=4 SV=1
#=GS tr|O87365|O87365_STAAU/1-51 DE [subseq from] Conserved domain protein OS=Staphylococcus aureus OX=1280 GN=parR PE=1 SV=1
#=GS tr|A0A7G3L2E1|A0A7G3L2E1_STAAU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 PE=4 SV=1
#=GS tr|E4PYH1|E4PYH1_STAAU/1-39 DE [subseq from] DUF655 domain-containing protein OS=Staphylococcus aureus OX=1280 GN=SUM_0041p2 PE=4 SV=1
#=GS tr|A0A0Q9XW80|A0A0Q9XW80_9STAP/1-51 DE [subseq from] RHH_1 domain-containing protein OS=Staphylococcus sp. NAM3COL9 OX=1667172 GN=ACA31_00310 PE=4 SV=1
query MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0K0ME10|A0A0K0ME10_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0C5BVQ8|A0A0C5BVQ8_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0D4ZYK6|A0A0D4ZYK6_STAEP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0H2XKQ4|A0A0H2XKQ4_STAA3/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0N9NJL4|A0A0N9NJL4_STAPS/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0U2CJ65|A0A0U2CJ65_STAEP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A133QXU6|A0A133QXU6_STASI/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A141BHY3|A0A141BHY3_STAXY/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A141HMK9|A0A141HMK9_STAA8/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A1B1UXS0|A0A1B1UXS0_STALU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A1S7BGJ1|A0A1S7BGJ1_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A418HED5|A0A418HED5_STAGA/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A507SJ94|A0A507SJ94_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A6N0I4W4|A0A6N0I4W4_STAHO/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A7G3T6L6|A0A7G3T6L6_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A848F022|A0A848F022_STACP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|O87365|O87365_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A7G3L2E1|A0A7G3L2E1_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|E4PYH1|E4PYH1_STAAU/1-39 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREA------------
tr|A0A0Q9XW80|A0A0Q9XW80_9STAP/1-51 -------------------MSKQETNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
# STOCKHOLM 1.0
#=GS UniRef90_A0A141BHY3/1-51 DE [subseq from] DNA-binding protein n=37 Tax=Staphylococcaceae TaxID=90964 RepID=A0A141BHY3_STAXY
#=GS UniRef90_UPI000A061283/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Mammaliicoccus sciuri TaxID=1296 RepID=UPI000A061283
#=GS UniRef90_UPI001E649B27/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Mammaliicoccus sciuri TaxID=1296 RepID=UPI001E649B27
#=GS UniRef90_UPI00201A2D50/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI00201A2D50
#=GS UniRef90_UPI0018EDBA69/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI0018EDBA69
#=GS UniRef90_UPI0005E12F5A/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus TaxID=1279 RepID=UPI0005E12F5A
#=GS UniRef90_UPI00207B21F3/1-51 DE [subseq from] plasmid segregation protein ParR n=2 Tax=Staphylococcus TaxID=1279 RepID=UPI00207B21F3
#=GS UniRef90_UPI0009836679/1-51 DE [subseq from] plasmid segregation protein ParR n=2 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI0009836679
#=GS UniRef90_UPI001F5439CD/1-51 DE [subseq from] plasmid segregation protein ParR n=11 Tax=Staphylococcaceae TaxID=90964 RepID=UPI001F5439CD
#=GS UniRef90_UPI000DA9B884/1-51 DE [subseq from] plasmid segregation protein ParR n=3 Tax=Bacillales TaxID=1385 RepID=UPI000DA9B884
#=GS UniRef90_A0A0Q9XW80/1-51 DE [subseq from] RHH_1 domain-containing protein n=1 Tax=Staphylococcus sp. NAM3COL9 TaxID=1667172 RepID=A0A0Q9XW80_9STAP
#=GS UniRef90_UPI001CCC4088/3-48 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Macrococcus armenti TaxID=2875764 RepID=UPI001CCC4088
#=GS UniRef90_UPI0014612D4C/1-49 DE [subseq from] De novo designed WSHC6 n=2 Tax=synthetic construct TaxID=32630 RepID=UPI0014612D4C
#=GS UniRef90_UPI000B802FE5/1-42 DE [subseq from] HEEH_rd4_0097 n=1 Tax=Escherichia coli TaxID=562 RepID=UPI000B802FE5
#=GS UniRef90_UPI001E281CEB/1-54 DE [subseq from] Network hallucinated protein 0738_mod n=1 Tax=synthetic construct TaxID=32630 RepID=UPI001E281CEB
query MGSSHHHHHHSSGLVP-GSHMDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
UniRef90_A0A141BHY3/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
UniRef90_UPI000A061283/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGENP
UniRef90_UPI001E649B27/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGDNP
UniRef90_UPI00201A2D50/1-51 --------------------MEKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGDNP
UniRef90_UPI0018EDBA69/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALLRYIEEFGENP
UniRef90_UPI0005E12F5A/1-51 --------------------MKKKE-TQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
UniRef90_UPI00207B21F3/1-51 --------------------MSKQE-TNHLLKIKKEDYPQIFDFLENVPKGTKTAHIREALIRYINDLGGSP
UniRef90_UPI0009836679/1-51 --------------------MDKKE-TQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGQNP
UniRef90_UPI001F5439CD/1-51 --------------------MSKQE-TNHLLKIKKKDYPQIFDFLENVPKGTKTAHIREALIRYINDLGGTP
UniRef90_UPI000DA9B884/1-51 --------------------MDKKE-TQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGNTP
UniRef90_A0A0Q9XW80/1-51 --------------------MSKQE-TNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
UniRef90_UPI001CCC4088/3-48 ----------------------KEV-NQTLLKIDKAEYPEIYDFLENVPRGTKTAHIREALIRYINDIN---
UniRef90_UPI0014612D4C/1-49 MGSSHHHHHHSSGLVPRGSHMTEDE-IRKLRKLLEEAEKKLYKLEDKTRR----------------------
UniRef90_UPI000B802FE5/1-42 MGSSHHHHHHSSGLVPRGSHMDVEEQIRRLEEVLKKNQPVTW------------------------------
UniRef90_UPI001E281CEB/1-54 MGSSHHHHHHSSGLVPRGSHMNIQV-SLQWE---DPKKGKVFSHTVNIPPGGTAEQIA--------------
#=GC RF xxxxxxxxxxxxxxxx.xxxxxxxx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
>query
MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
>tr|A0A2W3M096|A0A2W3M096_STAAU Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 GN=C7Q70_14145 PE=4 SV=1
-------------------MDKKETQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGNTP
>tr|A0A0Q9XW80|A0A0Q9XW80_9STAP Uncharacterized protein OS=Staphylococcus sp. NAM3COL9 GN=ACA31_00310 PE=4 SV=1
-------------------MSKQETNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
>tr|A0A1E5U0W4|A0A1E5U0W4_STAXY Uncharacterized protein OS=Staphylococcus xylosus GN=AST15_04830 PE=4 SV=1
-------------------MSKQETNHLLKIKKKDYPQIFDFLENVPKGTKTAHIREALIRYINDLGDTpP
This diff is collapsed.
# STOCKHOLM 1.0
#=GS MGYP000048211747/1-51 DE [subseq from] PL=00 UP=0 BIOMES=0000000011000
#=GS MGYP000256545448/1-51 DE [subseq from] PL=00 UP=0 BIOMES=0000000011000
#=GS MGYP000517307434/104-157 DE [subseq from] PL=11 UP=0 BIOMES=0000000011000
#=GS MGYP000971940026/195-224 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
#=GS MGYP000859660985/46-74 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
#=GS MGYP000859660985/83-111 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
query MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
MGYP000048211747/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
MGYP000256545448/1-51 -------------------MKKKETQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
MGYP000517307434/104-157 ----------------GDLLRQKETQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
MGYP000971940026/195-224 ------------------------------VKKSDLGQVTSFLKEVPEGKKQDVLDEVLK----------
MGYP000859660985/46-74 ------------------------------IKKSDLGQVASFLKEVPEGQKQEVLDQVL-----------
MGYP000859660985/83-111 ------------------------------IKKSDLGQVASFLKEVPEGQKQEVLDQVL-----------
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
# STOCKHOLM 1.0
#=GS tr|A0A0K0ME10|A0A0K0ME10_9STAP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus schleiferi OX=1295 GN=NP71_p00120 PE=4 SV=1
#=GS tr|A0A0C5BVQ8|A0A0C5BVQ8_STAAU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 GN=parR PE=4 SV=1
#=GS tr|A0A0D4ZYK6|A0A0D4ZYK6_STAEP/1-51 DE [subseq from] DNA-binding protein ParR OS=Staphylococcus epidermidis OX=1282 GN=parR PE=4 SV=1
#=GS tr|A0A0H2XKQ4|A0A0H2XKQ4_STAA3/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus (strain USA300) OX=367830 GN=SAUSA300_pUSA030035 PE=4 SV=1
#=GS tr|A0A0N9NJL4|A0A0N9NJL4_STAPS/1-51 DE [subseq from] Putative plasmid segregation protein ParR OS=Staphylococcus pseudintermedius OX=283734 GN=parR PE=4 SV=1
#=GS tr|A0A0U2CJ65|A0A0U2CJ65_STAEP/1-51 DE [subseq from] Plasmid segregation protein OS=Staphylococcus epidermidis OX=1282 GN=parR PE=4 SV=1
#=GS tr|A0A133QXU6|A0A133QXU6_STASI/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus simulans OX=1286 GN=HMPREF3215_00002 PE=4 SV=1
#=GS tr|A0A141BHY3|A0A141BHY3_STAXY/1-51 DE [subseq from] DNA-binding protein OS=Staphylococcus xylosus OX=1288 GN=p11 PE=4 SV=1
#=GS tr|A0A141HMK9|A0A141HMK9_STAA8/1-51 DE [subseq from] ParR OS=Staphylococcus aureus subsp. aureus RN4220 OX=561307 GN=pGO400_p33 PE=4 SV=1
#=GS tr|A0A1B1UXS0|A0A1B1UXS0_STALU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus lugdunensis OX=28035 GN=parR PE=4 SV=1
#=GS tr|A0A1S7BGJ1|A0A1S7BGJ1_STAAU/1-51 DE [subseq from] DNA-binding protein ParR OS=Staphylococcus aureus OX=1280 GN=parR PE=4 SV=1
#=GS tr|A0A418HED5|A0A418HED5_STAGA/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus gallinarum OX=1293 GN=BUY97_07835 PE=4 SV=1
#=GS tr|A0A507SJ94|A0A507SJ94_9STAP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus sp. SKL71187 OX=2497688 GN=EKV43_01520 PE=4 SV=1
#=GS tr|A0A6N0I4W4|A0A6N0I4W4_STAHO/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus hominis OX=1290 GN=FOB69_12695 PE=4 SV=1
#=GS tr|A0A7G3T6L6|A0A7G3T6L6_9STAP/1-51 DE [subseq from] Plasmid segregation protein OS=Staphylococcus equorum OX=246432 PE=4 SV=1
#=GS tr|A0A848F022|A0A848F022_STACP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus capitis OX=29388 GN=HHM13_04665 PE=4 SV=1
#=GS tr|O87365|O87365_STAAU/1-51 DE [subseq from] Conserved domain protein OS=Staphylococcus aureus OX=1280 GN=parR PE=1 SV=1
#=GS tr|A0A7G3L2E1|A0A7G3L2E1_STAAU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 PE=4 SV=1
#=GS tr|E4PYH1|E4PYH1_STAAU/1-39 DE [subseq from] DUF655 domain-containing protein OS=Staphylococcus aureus OX=1280 GN=SUM_0041p2 PE=4 SV=1
#=GS tr|A0A0Q9XW80|A0A0Q9XW80_9STAP/1-51 DE [subseq from] RHH_1 domain-containing protein OS=Staphylococcus sp. NAM3COL9 OX=1667172 GN=ACA31_00310 PE=4 SV=1
query MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0K0ME10|A0A0K0ME10_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0C5BVQ8|A0A0C5BVQ8_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0D4ZYK6|A0A0D4ZYK6_STAEP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0H2XKQ4|A0A0H2XKQ4_STAA3/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0N9NJL4|A0A0N9NJL4_STAPS/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0U2CJ65|A0A0U2CJ65_STAEP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A133QXU6|A0A133QXU6_STASI/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A141BHY3|A0A141BHY3_STAXY/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A141HMK9|A0A141HMK9_STAA8/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A1B1UXS0|A0A1B1UXS0_STALU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A1S7BGJ1|A0A1S7BGJ1_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A418HED5|A0A418HED5_STAGA/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A507SJ94|A0A507SJ94_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A6N0I4W4|A0A6N0I4W4_STAHO/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A7G3T6L6|A0A7G3T6L6_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A848F022|A0A848F022_STACP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|O87365|O87365_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A7G3L2E1|A0A7G3L2E1_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|E4PYH1|E4PYH1_STAAU/1-39 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREA------------
tr|A0A0Q9XW80|A0A0Q9XW80_9STAP/1-51 -------------------MSKQETNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
# STOCKHOLM 1.0
#=GS UniRef90_A0A141BHY3/1-51 DE [subseq from] DNA-binding protein n=37 Tax=Staphylococcaceae TaxID=90964 RepID=A0A141BHY3_STAXY
#=GS UniRef90_UPI000A061283/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Mammaliicoccus sciuri TaxID=1296 RepID=UPI000A061283
#=GS UniRef90_UPI001E649B27/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Mammaliicoccus sciuri TaxID=1296 RepID=UPI001E649B27
#=GS UniRef90_UPI00201A2D50/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI00201A2D50
#=GS UniRef90_UPI0018EDBA69/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI0018EDBA69
#=GS UniRef90_UPI0005E12F5A/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus TaxID=1279 RepID=UPI0005E12F5A
#=GS UniRef90_UPI00207B21F3/1-51 DE [subseq from] plasmid segregation protein ParR n=2 Tax=Staphylococcus TaxID=1279 RepID=UPI00207B21F3
#=GS UniRef90_UPI0009836679/1-51 DE [subseq from] plasmid segregation protein ParR n=2 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI0009836679
#=GS UniRef90_UPI001F5439CD/1-51 DE [subseq from] plasmid segregation protein ParR n=11 Tax=Staphylococcaceae TaxID=90964 RepID=UPI001F5439CD
#=GS UniRef90_UPI000DA9B884/1-51 DE [subseq from] plasmid segregation protein ParR n=3 Tax=Bacillales TaxID=1385 RepID=UPI000DA9B884
#=GS UniRef90_A0A0Q9XW80/1-51 DE [subseq from] RHH_1 domain-containing protein n=1 Tax=Staphylococcus sp. NAM3COL9 TaxID=1667172 RepID=A0A0Q9XW80_9STAP
#=GS UniRef90_UPI001CCC4088/3-48 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Macrococcus armenti TaxID=2875764 RepID=UPI001CCC4088
#=GS UniRef90_UPI0014612D4C/1-49 DE [subseq from] De novo designed WSHC6 n=2 Tax=synthetic construct TaxID=32630 RepID=UPI0014612D4C
#=GS UniRef90_UPI000B802FE5/1-42 DE [subseq from] HEEH_rd4_0097 n=1 Tax=Escherichia coli TaxID=562 RepID=UPI000B802FE5
#=GS UniRef90_UPI001E281CEB/1-54 DE [subseq from] Network hallucinated protein 0738_mod n=1 Tax=synthetic construct TaxID=32630 RepID=UPI001E281CEB
query MGSSHHHHHHSSGLVP-GSHMDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
UniRef90_A0A141BHY3/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
UniRef90_UPI000A061283/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGENP
UniRef90_UPI001E649B27/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGDNP
UniRef90_UPI00201A2D50/1-51 --------------------MEKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGDNP
UniRef90_UPI0018EDBA69/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALLRYIEEFGENP
UniRef90_UPI0005E12F5A/1-51 --------------------MKKKE-TQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
UniRef90_UPI00207B21F3/1-51 --------------------MSKQE-TNHLLKIKKEDYPQIFDFLENVPKGTKTAHIREALIRYINDLGGSP
UniRef90_UPI0009836679/1-51 --------------------MDKKE-TQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGQNP
UniRef90_UPI001F5439CD/1-51 --------------------MSKQE-TNHLLKIKKKDYPQIFDFLENVPKGTKTAHIREALIRYINDLGGTP
UniRef90_UPI000DA9B884/1-51 --------------------MDKKE-TQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGNTP
UniRef90_A0A0Q9XW80/1-51 --------------------MSKQE-TNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
UniRef90_UPI001CCC4088/3-48 ----------------------KEV-NQTLLKIDKAEYPEIYDFLENVPRGTKTAHIREALIRYINDIN---
UniRef90_UPI0014612D4C/1-49 MGSSHHHHHHSSGLVPRGSHMTEDE-IRKLRKLLEEAEKKLYKLEDKTRR----------------------
UniRef90_UPI000B802FE5/1-42 MGSSHHHHHHSSGLVPRGSHMDVEEQIRRLEEVLKKNQPVTW------------------------------
UniRef90_UPI001E281CEB/1-54 MGSSHHHHHHSSGLVPRGSHMNIQV-SLQWE---DPKKGKVFSHTVNIPPGGTAEQIA--------------
#=GC RF xxxxxxxxxxxxxxxx.xxxxxxxx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
{"2q2k": {"release_date": "2008-02-05", "chain_ids": ["A", "B"], "seqs": ["MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP", "MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP"], "no_chains": 2, "resolution": 3.0}}
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import shutil
import pickle
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import unittest
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldMultimerDataModule,OpenFoldDataModule
from openfold.model.model import AlphaFold
from openfold.utils.loss import AlphaFoldMultimerLoss
from tests.config import consts
import logging
logger = logging.getLogger(__name__)
import os
class TestMultimerDataModule(unittest.TestCase):
    """Smoke test for the multimer pipeline: build the multimer data module,
    pull one training example, run it through AlphaFold, and evaluate the
    multimer loss.
    """

    def setUp(self):
        """
        Set up model config.
        Use model_1_multimer_v3 for now.
        """
        self.config = model_config(
            "model_1_multimer_v3",
            train=True,
            low_prec=True)
        # NOTE(review): template_mmcif_dir is an absolute cluster-specific
        # path; this test can only run where that mount exists — TODO confirm
        # or move behind an environment variable.
        self.data_module = OpenFoldMultimerDataModule(
            config=self.config.data,
            batch_seed=42,
            train_epoch_len=100,
            template_mmcif_dir="/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/",
            template_release_dates_cache_path=os.path.join(
                os.getcwd(), "tests/test_data/mmcif_cache.json"),
            max_template_date="2500-01-01",
            train_data_dir=os.path.join(os.getcwd(), "tests/test_data/mmcifs"),
            train_alignment_dir=os.path.join(os.getcwd(), "tests/test_data/alignments/"),
            kalign_binary_path=shutil.which('kalign'),
            train_mmcif_data_cache_path=os.path.join(
                os.getcwd(), "tests/test_data/train_mmcifs_cache.json"),
            train_chain_data_cache_path=os.path.join(
                os.getcwd(), "tests/test_data/train_chain_data_cache.json"),
        )
        # Set up a small model for the forward pass.
        self.c = model_config(consts.model, train=True)
        self.c.loss.masked_msa.num_classes = 22  # somehow need overwrite this part in multimer loss config
        self.c.model.evoformer_stack.no_blocks = 4  # no need to go overboard here
        self.c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
        # deepspeed for this test
        self.model = AlphaFold(self.c)
        self.multimer_loss = AlphaFoldMultimerLoss(self.c.loss)

    def testPrepareData(self):
        """Fetch one training example, add a batch dimension, and check that
        the forward pass plus multimer loss run without error."""
        self.data_module.prepare_data()
        self.data_module.setup()
        train_dataset = self.data_module.train_dataset
        all_chain_features, ground_truth = train_dataset[1]

        # PEP 8 (E731): use a def rather than assigning a lambda to a name.
        def add_batch_size_dimension(t):
            return t.unsqueeze(0)

        all_chain_features = tensor_tree_map(
            add_batch_size_dimension, all_chain_features)
        with torch.no_grad():
            out = self.model(all_chain_features)
        self.multimer_loss(out, (all_chain_features, ground_truth))
\ No newline at end of file
...@@ -16,7 +16,7 @@ import torch ...@@ -16,7 +16,7 @@ import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import ( from openfold.data.data_modules import (
OpenFoldDataModule, OpenFoldDataModule,OpenFoldMultimerDataModule,
DummyDataLoader, DummyDataLoader,
) )
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
...@@ -27,7 +27,7 @@ from openfold.utils.callbacks import ( ...@@ -27,7 +27,7 @@ from openfold.utils.callbacks import (
EarlyStoppingVerbose, EarlyStoppingVerbose,
) )
from openfold.utils.exponential_moving_average import ExponentialMovingAverage from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca from openfold.utils.loss import AlphaFoldLoss, AlphaFoldMultimerLoss,lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose from openfold.utils.superimposition import superimpose
...@@ -257,6 +257,69 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -257,6 +257,69 @@ class OpenFoldWrapper(pl.LightningModule):
) )
class OpenFoldMultimerWrapper(OpenFoldWrapper):
    """Lightning wrapper for multimer training.

    Differs from OpenFoldWrapper in that each batch is an
    (all_chain_features, ground_truth) tuple and the loss is
    AlphaFoldMultimerLoss, which consumes that tuple.
    """

    def __init__(self, config):
        super(OpenFoldMultimerWrapper, self).__init__(config)
        self.config = config
        # Multimer-specific config overrides. They are applied after the
        # parent constructor, so the model/loss built there are discarded
        # and rebuilt below with the patched config.
        self.config.loss.masked_msa.num_classes = 22  # somehow need overwrite this part in multimer loss config
        # NOTE(review): no_blocks=4 / blocks_per_ckpt=None look like debug
        # settings copied from the unit test into the training wrapper —
        # confirm before training a real model.
        self.config.model.evoformer_stack.no_blocks = 4
        self.config.model.evoformer_stack.blocks_per_ckpt = None
        self.model = AlphaFold(config)
        self.loss = AlphaFoldMultimerLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        self.cached_weights = None
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        all_chain_features, ground_truth = batch
        if(self.ema.device != all_chain_features["aatype"].device):
            self.ema.to(all_chain_features["aatype"].device)
        # Run the model
        outputs = self(all_chain_features)
        # Compute loss
        loss = self.loss(
            outputs, (all_chain_features, ground_truth), _return_breakdown=False
        )
        # Log it
        self._log(loss, all_chain_features, outputs)
        return loss

    def validation_step(self, batch, batch_idx):
        all_chain_features, ground_truth = batch
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])
        # Run the model
        outputs = self(all_chain_features)
        # Compute loss and other metrics
        all_chain_features["use_clamped_fape"] = 0.
        # BUG FIX: AlphaFoldMultimerLoss is called with the
        # (all_chain_features, ground_truth) tuple in training_step; passing
        # all_chain_features alone here dropped the ground truth entirely.
        _, loss_breakdown = self.loss(
            outputs, (all_chain_features, ground_truth), _return_breakdown=True
        )
        self._log(loss_breakdown, all_chain_features, outputs, train=False)

    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None
def main(args): def main(args):
if(args.seed is not None): if(args.seed is not None):
seed_everything(args.seed) seed_everything(args.seed)
...@@ -266,7 +329,9 @@ def main(args): ...@@ -266,7 +329,9 @@ def main(args):
train=True, train=True,
low_prec=(str(args.precision) == "16") low_prec=(str(args.precision) == "16")
) )
if "multimer" in args.config_preset:
model_module = OpenFoldMultimerWrapper(config)
else:
model_module = OpenFoldWrapper(config) model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt): if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)): if(os.path.isdir(args.resume_from_ckpt)):
...@@ -293,6 +358,13 @@ def main(args): ...@@ -293,6 +358,13 @@ def main(args):
script_preset_(model_module) script_preset_(model_module)
#data_module = DummyDataLoader("new_batch.pickle") #data_module = DummyDataLoader("new_batch.pickle")
if "multimer" in args.config_preset:
data_module = OpenFoldMultimerDataModule(
config=config.data,
batch_seed=args.seed,
**vars(args)
)
else:
data_module = OpenFoldDataModule( data_module = OpenFoldDataModule(
config=config.data, config=config.data,
batch_seed=args.seed, batch_seed=args.seed,
...@@ -417,6 +489,10 @@ if __name__ == "__main__": ...@@ -417,6 +489,10 @@ if __name__ == "__main__":
help='''Cutoff for all templates. In training mode, templates are also help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target''' filtered by the release date of the target'''
) )
parser.add_argument(
"--train_mmcif_data_cache_path", type=str, default=None,
help="path to the json file which records all the information of mmcif structures used during training"
)
parser.add_argument( parser.add_argument(
"--distillation_data_dir", type=str, default=None, "--distillation_data_dir", type=str, default=None,
help="Directory containing training PDB files" help="Directory containing training PDB files"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment