Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
862dab6f
"...git@developer.sourcefind.cn:OpenDAS/mmdeploy.git" did not exist on "e4fb2aa4ea2a6347df67473bb2f02169510a1cab"
Commit
862dab6f
authored
Nov 19, 2021
by
Gustaf Ahdritz
Browse files
Fix bugs from previous commit
parent
34e9363c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
8 deletions
+5
-8
openfold/model/template.py
openfold/model/template.py
+1
-1
openfold/model/torchscript.py
openfold/model/torchscript.py
+0
-1
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+3
-5
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+1
-1
No files found.
openfold/model/template.py
View file @
862dab6f
...
@@ -315,7 +315,7 @@ class TemplatePairStack(nn.Module):
...
@@ -315,7 +315,7 @@ class TemplatePairStack(nn.Module):
expand_idx
[
-
3
]
=
t
.
shape
[
-
4
]
expand_idx
[
-
3
]
=
t
.
shape
[
-
4
]
mask
=
mask
.
expand
(
*
expand_idx
)
mask
=
mask
.
expand
(
*
expand_idx
)
(
t
,
)
=
checkpoint_blocks
(
t
,
=
checkpoint_blocks
(
blocks
=
[
blocks
=
[
partial
(
partial
(
b
,
b
,
...
...
openfold/model/torchscript.py
View file @
862dab6f
...
@@ -138,7 +138,6 @@ def _trace_module(module, batch_dims=None):
...
@@ -138,7 +138,6 @@ def _trace_module(module, batch_dims=None):
)
)
)
)
}
}
module
=
OPM
(
module
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
f
"tracing is not supported for modules of type
{
type
(
module
)
}
"
f
"tracing is not supported for modules of type
{
type
(
module
)
}
"
...
...
openfold/model/triangular_multiplicative_update.py
View file @
862dab6f
...
@@ -52,7 +52,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -52,7 +52,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
self
.
sigmoid
=
nn
.
Sigmoid
()
self
.
sigmoid
=
nn
.
Sigmoid
()
def
_combine_projections
(
def
_combine_projections
(
self
,
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -94,8 +94,7 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
...
@@ -94,8 +94,7 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
"""
"""
Implements Algorithm 11.
Implements Algorithm 11.
"""
"""
def
_combine_projections
(
def
_combine_projections
(
self
,
self
,
a
:
torch
.
Tensor
,
# [*, N_i, N_k, C]
a
:
torch
.
Tensor
,
# [*, N_i, N_k, C]
b
:
torch
.
Tensor
,
# [*, N_j, N_k, C]
b
:
torch
.
Tensor
,
# [*, N_j, N_k, C]
):
):
...
@@ -113,8 +112,7 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
...
@@ -113,8 +112,7 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
"""
"""
Implements Algorithm 12.
Implements Algorithm 12.
"""
"""
def
_combine_projections
(
def
_combine_projections
(
self
,
self
,
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
b
:
torch
.
Tensor
,
# [*, N_k, N_j, C]
b
:
torch
.
Tensor
,
# [*, N_k, N_j, C]
):
):
...
...
tests/test_triangular_multiplicative_update.py
View file @
862dab6f
...
@@ -32,7 +32,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -32,7 +32,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c
=
11
c
=
11
outgoing
=
True
outgoing
=
True
tm
=
TriangleMultiplicati
veUpdate
(
tm
=
TriangleMultiplicati
onOutgoing
(
c_z
,
c_z
,
c
,
c
,
outgoing
,
outgoing
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment