Commit c16e400e authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added feature transformations for making denser (14 atom) dimensions instead of 37 atom.

parent bc43ccf0
......@@ -414,3 +414,55 @@ def crop_templates(protein, max_templates):
if k.startswith('template_'):
protein[k] = v[:max_templates]
return protein
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0) for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
# Since all 14 atoms are not present in every residue, use this mask to tell which atom is there in this residue
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_to_atom37 = torch.tensor(restype_atom14_to_atom37, dtype=torch.int32)
restype_atom37_to_atom14 = torch.tensor(restype_atom37_to_atom14, dtype=torch.int32)
restype_atom14_mask = torch.tensor(restype_atom14_mask, dtype=torch.float32)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = torch.index_select(restype_atom14_to_atom37, 0, protein['aatype'])
residx_atom14_mask = torch.index_select(restype_atom14_mask, 0, protein['aatype'])
protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
# create the gather indices for mapping back
residx_atom37_to_atom14 = torch.index_select(restype_atom37_to_atom14, 0, protein['aatype'])
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
# create the corresponding mask
restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = torch.index_select(restype_atom37_mask, 0, protein['aatype'])
protein['atom37_atom_exists'] = residx_atom37_mask
return protein
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment