Commit e28c2828 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added feature transformations for creating pseudo-beta position for glycine.

parent f7bb6999
......@@ -254,3 +254,31 @@ def make_msa_mask(protein):
protein['msa_mask'] = torch.ones(protein['msa'].shape, dtype=torch.float32)
protein['msa_row_mask'] = torch.ones(protein['msa'].shape[0], dtype=torch.float32)
return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
"""Create pseudo beta features."""
is_gly = torch.eq(aatype, residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
pseudo_beta = torch.where(
torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
@curry1
def make_pseudo_beta(protein, prefix=''):
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert prefix in ['', 'template_']
protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
pseudo_beta_fn(
protein['template_aatype' if prefix else 'all_atom_aatype'],
protein[prefix + 'all_atom_positions'],
protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
return protein
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment