OpenDAS / FastFold

Commit ded582b2 (unverified)
Authored Jul 27, 2022 by shenggan; committed by GitHub on Jul 27, 2022
remove scale in fused softmax kernel (#34)
Parent: ad7f0cb5
Showing 5 changed files with 116 additions and 138 deletions (+116 -138)
fastfold/model/fastnn/kernel/__init__.py                              +2  -2
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp        +12 -15
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu  +84 -92
fastfold/model/fastnn/kernel/cuda_native/softmax.py                   +14 -16
fastfold/model/fastnn/ops.py                                          +4  -13
fastfold/model/fastnn/kernel/__init__.py
 from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
 from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
-from .cuda_native.softmax import softmax, scale_mask_softmax, scale_mask_bias_softmax
+from .cuda_native.softmax import softmax, mask_softmax, mask_bias_softmax

 __all__ = [
     "bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
-    "scale_mask_softmax", "scale_mask_bias_softmax"
+    "mask_softmax", "mask_bias_softmax"
 ]
\ No newline at end of file
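Note on the caller-visible change (a minimal sketch, not part of the commit): with the scale argument removed from the exported functions, callers are expected to apply the scaling themselves before invoking the fused softmax. The ref_* functions below are hypothetical plain-PyTorch stand-ins and assume the usual additive attention-mask convention; the real masking convention lives in softmax_cuda_kernel.cu.

    import torch

    def ref_scale_mask_softmax(logits, mask, scale):
        # old-style call: scaling happened inside the fused kernel
        return torch.softmax(logits * scale + mask, dim=-1)

    def ref_mask_softmax(logits, mask):
        # new-style call: the caller scales the logits (or q) beforehand
        return torch.softmax(logits + mask, dim=-1)

    logits = torch.randn(2, 4, 16, 16)
    mask = torch.zeros(2, 4, 16, 16)   # additive mask; shape and semantics are assumptions
    scale = 0.125

    # The two call styles produce identical attention weights.
    assert torch.allclose(ref_scale_mask_softmax(logits, mask, scale),
                          ref_mask_softmax(logits * scale, mask))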
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
@@ -3,28 +3,25 @@
 at::Tensor softmax(at::Tensor input, long long rows, long long cols);

 at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);

-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols, float scale);

-at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, long long rows, long long cols, float scale);
+at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols);

+at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, long long rows, long long cols);

-at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols, float scale);

-at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols, float scale);
+at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols);

+at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("forward", &softmax, "Softmax forward (CUDA)");
     m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
-    m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward, "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_forward", &fused_mask_softmax_forward, "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward, "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_backward", &fused_mask_softmax_backward, "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward,
+    m.def("fused_mask_bias_softmax_forward", &fused_mask_bias_softmax_forward,
          "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward,
+    m.def("fused_mask_bias_softmax_backward", &fused_mask_bias_softmax_backward,
          "Softmax forward (CUDA)");
 }
\ No newline at end of file
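For orientation, the bindings registered above become attributes of the compiled extension module. Below is a hedged JIT-build sketch; FastFold's actual build presumably goes through its setup script, and this assumes the two csrc files touched by this commit compile standalone with a CUDA toolchain available.

    import torch
    from torch.utils.cpp_extension import load

    # JIT-compile the extension from the two source files shown in this commit.
    fastfold_softmax_cuda = load(
        name="fastfold_softmax_cuda",
        sources=[
            "fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp",
            "fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu",
        ],
        verbose=True,
    )

    # The names passed to m.def(...) are exposed on the module object,
    # now declared without a trailing float scale parameter.
    assert hasattr(fastfold_softmax_cuda, "fused_mask_softmax_forward")
    assert hasattr(fastfold_softmax_cuda, "fused_mask_bias_softmax_backward")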
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
(This diff is collapsed and is not shown.)
fastfold/model/fastnn/kernel/cuda_native/softmax.py
@@ -31,18 +31,17 @@ class SoftmaxAffineFunction(torch.autograd.Function):
         return grad_input


-class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
+class FusedMaskSoftmaxFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, mask, scale):
+    def forward(ctx, input, mask):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_softmax_forward(input_, mask_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_softmax_forward(input_, mask_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_)
-        ctx.scale = scale

         return output
@@ -52,25 +51,24 @@ class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
         output, mask_ = ctx.saved_tensors
         grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_scale_mask_softmax_backward(grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)

         return grad_input.contiguous(), None, None


-class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
+class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, mask, bias, scale):
+    def forward(ctx, input, mask, bias):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         bias_ = bias.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_forward(input_, mask_, bias_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(input_, mask_, bias_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_, bias_)
-        ctx.scale = scale

         return output
@@ -80,8 +78,8 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
         output, mask_, bias_ = ctx.saved_tensors
         grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_backward(grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_bias_softmax_backward(grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols)
         grad_input = grad_input.contiguous()
@@ -91,5 +89,5 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
 softmax = SoftmaxAffineFunction.apply
-scale_mask_softmax = FusedScaleMaskSoftmaxFunction.apply
-scale_mask_bias_softmax = FusedScaleMaskBiasSoftmaxFunction.apply
+mask_softmax = FusedMaskSoftmaxFunction.apply
+mask_bias_softmax = FusedMaskBiasSoftmaxFunction.apply
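To make the updated interface concrete, here is a hypothetical CPU reference with the same call signature as the new FusedMaskSoftmaxFunction (forward(ctx, input, mask), no scale argument). It is an illustrative fallback, not FastFold's kernel; it assumes an additive mask and uses the standard softmax backward, dx = y * (dy - sum(dy * y)), which the fused backward presumably implements on the GPU.

    import torch

    class RefMaskSoftmaxFunction(torch.autograd.Function):
        """Plain-PyTorch stand-in mirroring the post-commit interface (no scale arg)."""

        @staticmethod
        def forward(ctx, input, mask):
            # Assumption: mask is additive (large negative values at masked positions).
            output = torch.softmax(input + mask, dim=-1)
            ctx.save_for_backward(output)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            (output,) = ctx.saved_tensors
            # Standard softmax backward over the softmax dimension.
            dot = (grad_output * output).sum(dim=-1, keepdim=True)
            return output * (grad_output - dot), None

    ref_mask_softmax = RefMaskSoftmaxFunction.apply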
fastfold/model/fastnn/ops.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from fastfold.model.fastnn.kernel import scale_mask_softmax, scale_mask_bias_softmax
+from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
 from fastfold.model.fastnn.kernel import LayerNorm
 from .initializer import glorot_uniform_af
@@ -160,26 +160,17 @@ class SelfAttention(nn.Module):
         qkv = self.to_qkv(in_data).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

-        # q = self.to_q(in_data)
-        # k = self.to_k(in_data)
-        # v = self.to_k(in_data)
-        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), [q, k, v])
-        # q = q * self.scaling
+        q = q * self.scaling

         logits = torch.matmul(q, k.transpose(-1, -2))
-        # logits += mask

         if nonbatched_bias is not None:
-            # logits += nonbatched_bias.unsqueeze(1)
             bias = gather_async_opp(*nonbatched_bias, dim=1)
             bias = rearrange(bias, 'b q k h -> b h q k')
-            weights = scale_mask_bias_softmax(logits, mask, bias.unsqueeze(1), self.scaling)
+            weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
         else:
-            weights = scale_mask_softmax(logits, mask, self.scaling)
-            # weights = torch.softmax(logits, dim=-1)
-            # weights = softmax(logits)
+            weights = mask_softmax(logits, mask)

         weighted_avg = torch.matmul(weights, v)
         weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
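The SelfAttention change relies on a simple identity: scaling q before the matmul equals scaling the logits afterwards, (q * s) @ k^T == s * (q @ k^T), so applying self.scaling to q and calling the scale-free mask_softmax leaves the attention weights unchanged. A quick numeric check follows; the shapes and the scaling value are illustrative, not taken from the commit.

    import torch

    q = torch.randn(2, 3, 8, 16)
    k = torch.randn(2, 3, 8, 16)
    scaling = 16 ** -0.5   # e.g. 1/sqrt(head_dim); hypothetical value

    scaled_q_logits = torch.matmul(q * scaling, k.transpose(-1, -2))
    scaled_logits = torch.matmul(q, k.transpose(-1, -2)) * scaling

    # Identical up to floating-point rounding.
    assert torch.allclose(scaled_q_logits, scaled_logits, atol=1e-6)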