OpenDAS / FastFold · Commits

Commit a65d5009 (Unverified), authored Jun 03, 2022 by shenggan, committed via GitHub on Jun 03, 2022:

use template for fused softmax & add unittest for fused softmax (#26)
Parent: 771d4b83

Changes: 2 changed files, with 171 additions and 394 deletions (+171 / -394):
- fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu: +137 / -394
- tests/test_fastnn/test_softmax.py: +34 / -0
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu (diff collapsed)
tests/test_fastnn/test_softmax.py (new file, mode 100644):
```python
import torch

from fastfold.model.fastnn.kernel import softmax


def test_softmax():
    # [batch, dim]
    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")
    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}

    for shape in test_shape:
        for dtype in test_dtype:
            sample_input = torch.rand(shape).to(device=test_device, dtype=dtype).requires_grad_(True)
            sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)

            # Forward
            torch_out = torch.nn.functional.softmax(sample_input, dim=-1)
            fastnn_out = softmax(sample_input_fastnn)

            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            # Backward
            out_grad = torch.rand_like(torch_out).requires_grad_(False)
            torch_out.backward(out_grad)
            fastnn_out.backward(out_grad)

            backward_error = torch.max(
                torch.abs(sample_input.grad - sample_input_fastnn.grad)).cpu().item()
            assert backward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
```
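The test's correctness check — maximum absolute error between a reference and a candidate implementation, measured in both the forward and backward pass — can be sketched without the CUDA kernel or a GPU. Below, a plain-PyTorch numerically stable softmax (`naive_softmax`, a name introduced here for illustration, not part of FastFold) stands in for the fused kernel, and is compared against `torch.nn.functional.softmax` using the same pattern as the test above:

```python
import torch


def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable softmax over the last dim: subtract the row max
    # before exponentiating so large logits cannot overflow.
    shifted = x - x.max(dim=-1, keepdim=True).values
    e = shifted.exp()
    return e / e.sum(dim=-1, keepdim=True)


def max_abs_error(a: torch.Tensor, b: torch.Tensor) -> float:
    return torch.max(torch.abs(a - b)).item()


# Two independent leaf tensors sharing the same data, as in the test above,
# so each implementation accumulates its own .grad.
x_ref = torch.rand(64, 129, dtype=torch.float32).requires_grad_(True)
x_alt = x_ref.detach().clone().requires_grad_(True)

# Forward comparison against the reference implementation.
out_ref = torch.nn.functional.softmax(x_ref, dim=-1)
out_alt = naive_softmax(x_alt)
fwd_err = max_abs_error(out_ref, out_alt)
assert fwd_err < 1e-5

# Backward comparison: push the same upstream gradient through both graphs
# and compare the input gradients.
g = torch.rand_like(out_ref)
out_ref.backward(g)
out_alt.backward(g)
bwd_err = max_abs_error(x_ref.grad, x_alt.grad)
assert bwd_err < 1e-5
```

The float32 tolerance here (1e-5) mirrors the test's loosest-to-tightest ordering: half-precision dtypes in the real test get a much looser bound (10e-2) because both the fused kernel and the reference accumulate rounding error at 16-bit precision.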