"vscode:/vscode.git/clone" did not exist on "754d2ba82eaf251cfb981e3d4b9c2dbe4962bf08"
Commit e98c202d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix IPA bug, add missing dtype specs

parent ad2e5c97
...@@ -297,8 +297,8 @@ class InvariantPointAttention(nn.Module): ...@@ -297,8 +297,8 @@ class InvariantPointAttention(nn.Module):
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
) )
a = a + math.sqrt(1. / (3 * self.c_hidden)) a = a * math.sqrt(1. / (3 * self.c_hidden))
a = a + math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1)) a = a + (math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3] # [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
...@@ -759,39 +759,38 @@ class StructureModule(nn.Module): ...@@ -759,39 +759,38 @@ class StructureModule(nn.Module):
return outputs return outputs
def _init_residue_constants(self, dtype, device): def _init_residue_constants(self, float_dtype, device):
if(self.default_frames is None): if(self.default_frames is None):
self.default_frames = torch.tensor( self.default_frames = torch.tensor(
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
dtype=dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
if(self.group_idx is None): if(self.group_idx is None):
self.group_idx = torch.tensor( self.group_idx = torch.tensor(
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
dtype=dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
if(self.atom_mask is None): if(self.atom_mask is None):
self.atom_mask = torch.tensor( self.atom_mask = torch.tensor(
restype_atom14_mask, restype_atom14_mask,
dtype=dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
if(self.lit_positions is None): if(self.lit_positions is None):
self.lit_positions = torch.tensor( self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
dtype=dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
def torsion_angles_to_frames(self, t, alpha, f): def torsion_angles_to_frames(self, t, alpha, f):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(f.dtype, f.device) self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying # Separated purely to make testing less annoying
return _torsion_angles_to_frames(t, alpha, f, self.default_frames) return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
...@@ -802,8 +801,7 @@ class StructureModule(nn.Module): ...@@ -802,8 +801,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these # TODO: Maybe this stuff should be done on CPU instead (so these
# arrays # arrays
self._init_residue_constants(f.dtype, f.device) self._init_residue_constants(t.rots.dtype, t.rots.device)
return _frames_and_literature_positions_to_atom14_pos( return _frames_and_literature_positions_to_atom14_pos(
t, t,
f, f,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment