Commit e98c202d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix IPA bug, add missing dtype specs

parent ad2e5c97
......@@ -297,8 +297,8 @@ class InvariantPointAttention(nn.Module):
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a = a + math.sqrt(1. / (3 * self.c_hidden))
a = a + math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1))
a = a * math.sqrt(1. / (3 * self.c_hidden))
a = a + (math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
......@@ -759,39 +759,38 @@ class StructureModule(nn.Module):
return outputs
def _init_residue_constants(self, dtype, device):
def _init_residue_constants(self, float_dtype, device):
if(self.default_frames is None):
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=dtype,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if(self.group_idx is None):
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
dtype=dtype,
device=device,
requires_grad=False,
)
if(self.atom_mask is None):
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=dtype,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if(self.lit_positions is None):
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=dtype,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, t, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(f.dtype, f.device)
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
......@@ -802,8 +801,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these
# arrays
self._init_residue_constants(f.dtype, f.device)
self._init_residue_constants(t.rots.dtype, t.rots.device)
return _frames_and_literature_positions_to_atom14_pos(
t,
f,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment