Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
50078a62
Commit
50078a62
authored
Apr 26, 2022
by
Gustaf Ahdritz
Browse files
Add type check to attention kernel
parent
aada2a46
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
1 deletion
+6
-1
openfold/utils/kernel/attention_core.py
openfold/utils/kernel/attention_core.py
+5
-0
tests/test_loss.py
tests/test_loss.py
+1
-1
No files found.
openfold/utils/kernel/attention_core.py
View file @
50078a62
...
...
@@ -20,11 +20,16 @@ import torch
attn_core_inplace_cuda
=
importlib
.
import_module
(
"attn_core_inplace_cuda"
)
SUPPORTED_DTYPES
=
[
torch
.
float32
,
torch
.
bfloat16
]
class
AttentionCoreFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
bias_1
=
None
,
bias_2
=
None
):
if
(
bias_1
is
None
and
bias_2
is
not
None
):
raise
ValueError
(
"bias_1 must be specified before bias_2"
)
if
(
q
.
dtype
not
in
SUPPORTED_DTYPES
):
raise
ValueError
(
"Unsupported datatype"
)
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
...
...
tests/test_loss.py
View file @
50078a62
...
...
@@ -671,7 +671,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_backbone_loss
(
self
):
def
test_backbone_loss
_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_sm
=
config
.
model
.
heads
.
structure_module
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment