Commit 9eda0b43 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix incorrect types in data pipeline

parent 05e08750
......@@ -60,7 +60,7 @@ def fix_templates_aatype(protein):
# Map hhsearch-aatype to our aatype.
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int32
new_order_list, dtype=torch.int64
).expand(num_templates, -1)
protein['template_aatype'] = torch.gather(
new_order, 1, index=protein['template_aatype']
......@@ -512,13 +512,13 @@ def make_atom14_masks(protein):
)
protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back
residx_atom37_to_atom14 = torch.index_select(
restype_atom37_to_atom14, 0, protein['aatype']
)
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14.long()
# create the corresponding mask
restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32)
......
......@@ -73,7 +73,7 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES = {
'template_aatype': np.float32,
'template_aatype': np.int64,
'template_all_atom_masks': np.float32,
'template_all_atom_positions': np.float32,
'template_domain_names': np.object,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment