Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e98c202d
"lib/bindings/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "fc4da34502ebb0e32ec06df5c4d150d6d663228b"
Commit
e98c202d
authored
Sep 27, 2021
by
Gustaf Ahdritz
Browse files
Fix IPA bug, add missing dtype specs
parent
ad2e5c97
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
10 deletions
+8
-10
openfold/model/structure_module.py
openfold/model/structure_module.py
+8
-10
No files found.
openfold/model/structure_module.py
View file @
e98c202d
...
@@ -297,8 +297,8 @@ class InvariantPointAttention(nn.Module):
...
@@ -297,8 +297,8 @@ class InvariantPointAttention(nn.Module):
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
)
)
a
=
a
+
math
.
sqrt
(
1.
/
(
3
*
self
.
c_hidden
))
a
=
a
*
math
.
sqrt
(
1.
/
(
3
*
self
.
c_hidden
))
a
=
a
+
math
.
sqrt
(
1.
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
))
a
=
a
+
(
math
.
sqrt
(
1.
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
))
)
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
...
@@ -759,39 +759,38 @@ class StructureModule(nn.Module):
...
@@ -759,39 +759,38 @@ class StructureModule(nn.Module):
return
outputs
return
outputs
def
_init_residue_constants
(
self
,
dtype
,
device
):
def
_init_residue_constants
(
self
,
float_
dtype
,
device
):
if
(
self
.
default_frames
is
None
):
if
(
self
.
default_frames
is
None
):
self
.
default_frames
=
torch
.
tensor
(
self
.
default_frames
=
torch
.
tensor
(
restype_rigid_group_default_frame
,
restype_rigid_group_default_frame
,
dtype
=
dtype
,
dtype
=
float_
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
if
(
self
.
group_idx
is
None
):
if
(
self
.
group_idx
is
None
):
self
.
group_idx
=
torch
.
tensor
(
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
restype_atom14_to_rigid_group
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
if
(
self
.
atom_mask
is
None
):
if
(
self
.
atom_mask
is
None
):
self
.
atom_mask
=
torch
.
tensor
(
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
restype_atom14_mask
,
dtype
=
dtype
,
dtype
=
float_
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
if
(
self
.
lit_positions
is
None
):
if
(
self
.
lit_positions
is
None
):
self
.
lit_positions
=
torch
.
tensor
(
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
dtype
=
dtype
,
dtype
=
float_
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
f
.
dtype
,
f
.
device
)
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
# Separated purely to make testing less annoying
# Separated purely to make testing less annoying
return
_torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
return
_torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
...
@@ -802,8 +801,7 @@ class StructureModule(nn.Module):
...
@@ -802,8 +801,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these
# TODO: Maybe this stuff should be done on CPU instead (so these
# arrays
# arrays
self
.
_init_residue_constants
(
f
.
dtype
,
f
.
device
)
self
.
_init_residue_constants
(
t
.
rots
.
dtype
,
t
.
rots
.
device
)
return
_frames_and_literature_positions_to_atom14_pos
(
return
_frames_and_literature_positions_to_atom14_pos
(
t
,
t
,
f
,
f
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment