Commit bc43ccf0 authored by Sachin Kadyan
Browse files

Added feature transformations for making MSA features.

parent 0b32f2d5
......@@ -368,6 +368,42 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
return protein
@curry1
def make_msa_feat(protein):
    """Assemble the ``msa_feat`` and ``target_feat`` entries of ``protein``.

    Reads ``between_segment_residues``, ``aatype``, ``msa`` and
    ``deletion_matrix`` from ``protein``. When ``cluster_profile`` is present,
    ``cluster_profile`` and ``cluster_deletion_mean`` are folded into the MSA
    features as well. When ``extra_deletion_matrix`` is present,
    ``extra_has_deletion`` and ``extra_deletion_value`` are also written.

    The dict is modified in place and the same dict is returned.
    """

    def _scaled_atan(deletions):
        # Shared transform for every deletion-count input: atan(d / 3) * 2 / pi.
        return torch.atan(deletions / 3.) * (2. / np.pi)

    # Domain-break indicator, clipped to [0, 1]. It is always zero for chains
    # and is kept only for compatibility with domain datasets.
    break_flag = torch.clip(
        protein['between_segment_residues'].to(torch.float32), 0, 1
    )
    # One-hot of the original sequence; every target gets it.
    sequence_1hot = make_one_hot(protein['aatype'], 21)
    target_parts = [break_flag.unsqueeze(-1), sequence_1hot]

    # Per-MSA-entry features: one-hot residue identity, a clipped "has
    # deletion" indicator, and the scaled deletion count.
    alignment_1hot = make_one_hot(protein['msa'], 23)
    deletion_flag = torch.clip(protein['deletion_matrix'], 0., 1.)
    deletion_scaled = _scaled_atan(protein['deletion_matrix'])
    msa_parts = [
        alignment_1hot,
        deletion_flag.unsqueeze(-1),
        deletion_scaled.unsqueeze(-1),
    ]

    # Cluster statistics are optional; append them only when available.
    if 'cluster_profile' in protein:
        mean_deletion_scaled = _scaled_atan(protein['cluster_deletion_mean'])
        msa_parts.append(protein['cluster_profile'])
        msa_parts.append(mean_deletion_scaled.unsqueeze(-1))

    # The extra-MSA deletion features are stored under their own keys and are
    # not concatenated into msa_feat.
    if 'extra_deletion_matrix' in protein:
        extra_deletions = protein['extra_deletion_matrix']
        protein['extra_has_deletion'] = torch.clip(extra_deletions, 0., 1.)
        protein['extra_deletion_value'] = _scaled_atan(extra_deletions)

    # Join each feature group along the last axis.
    protein['msa_feat'] = torch.cat(msa_parts, dim=-1)
    protein['target_feat'] = torch.cat(target_parts, dim=-1)
    return protein
@curry1
def select_feat(protein, feature_list):
    """Return a new dict with only the entries of ``protein`` whose key is in ``feature_list``.

    ``protein`` itself is left untouched; values are carried over by reference.
    """
    selected = {}
    for name, value in protein.items():
        if name in feature_list:
            selected[name] = value
    return selected
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment