OpenDAS / FastFold · Commit ded582b2 (unverified)

remove scale in fused softmax kernel (#34)

Authored Jul 27, 2022 by shenggan; committed by GitHub on Jul 27, 2022.
Parent: ad7f0cb5

Showing 5 changed files with 116 additions and 138 deletions (+116 −138).
Files changed:

  fastfold/model/fastnn/kernel/__init__.py                               +2   −2
  fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp         +12  −15
  fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu   +84  −92
  fastfold/model/fastnn/kernel/cuda_native/softmax.py                    +14  −16
  fastfold/model/fastnn/ops.py                                           +4   −13
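Why the scale argument can go away: scalar multiplication commutes with the matrix product, so scaling Q before the Q·Kᵀ matmul produces exactly the same logits, and therefore the same softmax output, as scaling the product inside the kernel:

    softmax(s * (Q K^T) + bias) = softmax((s * Q) K^T + bias)

The ops.py change below applies the scale up front with `q = q * self.scaling`, which lets the CUDA kernels, their C++ declarations, the pybind bindings, and the Python autograd wrappers all drop their trailing `scale` parameter.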
fastfold/model/fastnn/kernel/__init__.py  (+2 −2)  [view file @ ded582b2]

 from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
 from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
-from .cuda_native.softmax import softmax, scale_mask_softmax, scale_mask_bias_softmax
+from .cuda_native.softmax import softmax, mask_softmax, mask_bias_softmax

 __all__ = [
     "bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
-    "scale_mask_softmax", "scale_mask_bias_softmax"
+    "mask_softmax", "mask_bias_softmax"
 ]
\ No newline at end of file
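For importers of this package only the two fused names change; any scaling now has to be folded into the logits by the caller (as ops.py does below). A minimal sketch of the updated import, using nothing beyond what `__all__` above exports:

# Before: from fastfold.model.fastnn.kernel import scale_mask_softmax, scale_mask_bias_softmax
from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax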
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp  (+12 −15)  [view file @ ded582b2]

@@ -3,28 +3,25 @@
 at::Tensor softmax(at::Tensor input, long long rows, long long cols);

 at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);

-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
-                                            long long cols, float scale);
+at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
+                                      long long cols);

-at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
-                                             long long rows, long long cols, float scale);
+at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
+                                       long long rows, long long cols);

-at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
-                                                 long long rows, long long cols, float scale);
+at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
+                                           long long rows, long long cols);

-at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
-                                                  at::Tensor mask, at::Tensor bias, long long rows,
-                                                  long long cols, float scale);
+at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
+                                            at::Tensor mask, at::Tensor bias, long long rows,
+                                            long long cols);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("forward", &softmax, "Softmax forward (CUDA)");
     m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
-    m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward,
-          "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward,
-          "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward,
-          "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward,
-          "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_forward", &fused_mask_softmax_forward, "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_backward", &fused_mask_softmax_backward, "Softmax forward (CUDA)");
+    m.def("fused_mask_bias_softmax_forward", &fused_mask_bias_softmax_forward,
+          "Softmax forward (CUDA)");
+    m.def("fused_mask_bias_softmax_backward", &fused_mask_bias_softmax_backward,
+          "Softmax forward (CUDA)");
 }
\ No newline at end of file
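For code that calls the built extension directly rather than through the autograd wrappers, the only visible change is the dropped trailing scale argument. A minimal sketch, mirroring how `FusedMaskSoftmaxFunction.forward` below derives `rows` and `cols`; the wrapper function here is illustrative, not part of the repository, and it assumes the compiled extension is importable under the `fastfold_softmax_cuda` name used in softmax.py:

from functools import reduce
from operator import mul

import torch
import fastfold_softmax_cuda  # assumption: the built CUDA extension, as referenced in softmax.py


def call_fused_mask_softmax(input_: torch.Tensor, mask_: torch.Tensor) -> torch.Tensor:
    rows = reduce(mul, input_.shape[:-1])
    cols = input_.shape[-1]
    # Old binding: fused_scale_mask_softmax_forward(input_, mask_, rows, cols, scale)
    return fastfold_softmax_cuda.fused_mask_softmax_forward(
        input_.contiguous(), mask_.contiguous(), rows, cols)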
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu  (+84 −92)  [view file @ ded582b2]

@@ -330,8 +330,8 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
 ////////////////

 template <typename T, int cols_per_thread>
-__global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long long rows,
-                                            long long cols, float scale, int head) {
+__global__ void fastfold_softmax_mask(T *input, T *mask, T *output, long long rows,
+                                      long long cols, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -349,7 +349,7 @@ __global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long l
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * 1e9;
             } else {
-                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
+                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]);
             }
         }
@@ -376,8 +376,8 @@ __global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long l
 }

 template <typename T, int block_size>
-__global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, long long rows,
-                                               long long cols, float scale, int head) {
+__global__ void fastfold_softmax_mask_sm(T *input, T *mask, T *output, long long rows,
+                                         long long cols, int head) {
     extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
     auto *buf = reinterpret_cast<float *>(shared_buf);
     const int tid = threadIdx.x;
@@ -389,7 +389,7 @@ __global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, lon
             if (mask_ptr[id] == 0) {
                 buf[id] = -1 * 1e9;
             } else {
-                buf[id] = input[row * cols + id] * scale;
+                buf[id] = input[row * cols + id];
             }
             thread_max = max(thread_max, buf[id]);
         }
@@ -410,8 +410,8 @@ __global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, lon
     }
 }

-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
-                                            long long cols, float scale) {
+at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
+                                      long long cols) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
@@ -423,33 +423,33 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
     if (cols <= 32) {
         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask<float, 1>
-                <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
-                                  (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask<float, 1>
+                <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
+                                  (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask<at::Half, 1>
-                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                                  (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask<at::Half, 1>
+                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                                  (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask<at::BFloat16, 1>
-                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                                  (at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask<at::BFloat16, 1>
+                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                                  (at::BFloat16 *)input.data_ptr(), rows, cols, head);
         }
     }
 #define COLS_CASE(col_per_thread)                                                            \
     else if (cols <= col_per_thread * 32) {                                                   \
         if (input.dtype() == torch::kFloat32) {                                               \
-            fastfold_softmax_scale_mask<float, col_per_thread>                                \
-                <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),        \
-                                  (float *)input.data_ptr(), rows, cols, scale, head);        \
+            fastfold_softmax_mask<float, col_per_thread>                                      \
+                <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),        \
+                                  (float *)input.data_ptr(), rows, cols, head);               \
         } else if (input.dtype() == torch::kFloat16) {                                        \
-            fastfold_softmax_scale_mask<at::Half, col_per_thread>                             \
-                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),  \
-                                  (at::Half *)input.data_ptr(), rows, cols, scale, head);     \
+            fastfold_softmax_mask<at::Half, col_per_thread>                                   \
+                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),  \
+                                  (at::Half *)input.data_ptr(), rows, cols, head);            \
         } else if (input.dtype() == torch::kBFloat16) {                                       \
-            fastfold_softmax_scale_mask<at::BFloat16, col_per_thread><<<grid, block>>>(       \
-                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),            \
-                (at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);                   \
+            fastfold_softmax_mask<at::BFloat16, col_per_thread><<<grid, block>>>(             \
+                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),            \
+                (at::BFloat16 *)input.data_ptr(), rows, cols, head);                          \
         }                                                                                     \
     }
     COLS_CASE(2)
@@ -493,26 +493,25 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
         const size_t smem = cols * sizeof(float);

         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask_sm<float, 128>
-                <<<grid, block, smem>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
-                                        (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_sm<float, 128>
+                <<<grid, block, smem>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
+                                        (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask_sm<at::Half, 128>
-                <<<grid, block, smem>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                                        (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_sm<at::Half, 128>
+                <<<grid, block, smem>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                                        (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask_sm<at::BFloat16, 128><<<grid, block, smem>>>(
-                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                (at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_sm<at::BFloat16, 128><<<grid, block, smem>>>(
+                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                (at::BFloat16 *)input.data_ptr(), rows, cols, head);
         }
     }

     return input;
 }

 template <typename T>
-__global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_input, T *mask,
-                                                 long long rows, long long cols, float scale,
-                                                 int head) {
+__global__ void fastfold_softmax_mask_grad(T *d_output, T *output, T *d_input, T *mask,
+                                           long long rows, long long cols, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -559,7 +558,7 @@ __global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_in
         for (int i = 0; i < cols_this_thread; ++i) {
             if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
                 row_d_input[lane_id * cols_per_thread + i] =
-                    static_cast<T>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
+                    static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
             } else {
                 row_d_input[lane_id * cols_per_thread + i] = 0;
             }
@@ -567,9 +566,8 @@ __global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_in
     }
 }

-at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
-                                             at::Tensor mask, long long rows, long long cols,
-                                             float scale) {
+at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask,
+                                       long long rows, long long cols) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
@@ -580,19 +578,18 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
-            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
-            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_mask_grad<float><<<grid, block>>>(
+            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
+            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kFloat16) {
-        fastfold_softmax_scale_mask_grad<at::Half>
-            <<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
-                              (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows,
-                              cols, scale, head);
+        fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
+            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
+            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kBFloat16) {
-        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
-            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
-            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
-            scale, head);
+        fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
+            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
+            head);
     }

     return grad_input;
@@ -601,9 +598,8 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
 ////////////////

 template <typename T, int cols_per_thread>
-__global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *output,
-                                                 long long rows, long long cols, float scale,
-                                                 int head) {
+__global__ void fastfold_softmax_mask_bias(T *input, T *mask, T *bias, T *output,
+                                           long long rows, long long cols, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -622,8 +618,8 @@ __global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * 10e9;
             } else {
-                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
-                buf[i] += static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
+                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) +
+                         static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
             }
         }
@@ -650,9 +646,8 @@ __global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *
 }

 template <typename T, int block_size>
-__global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias, T *output,
-                                                    long long rows, long long cols, float scale,
-                                                    int head) {
+__global__ void fastfold_softmax_mask_bias_sm(T *input, T *mask, T *bias, T *output,
+                                              long long rows, long long cols, int head) {
     extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
     auto *buf = reinterpret_cast<float *>(shared_buf);
     const int tid = threadIdx.x;
@@ -665,7 +660,7 @@ __global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias,
             if (mask_ptr[id] == 0) {
                 buf[id] = -1 * 1e9;
             } else {
-                buf[id] = input[row * cols + id] * scale + bias_ptr[id];
+                buf[id] = input[row * cols + id] + bias_ptr[id];
             }
             thread_max = max(thread_max, buf[id]);
         }
@@ -686,8 +681,8 @@ __global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias,
     }
 }

-at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
-                                                 long long rows, long long cols, float scale) {
+at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
+                                           long long rows, long long cols) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     CHECK_INPUT(bias);
@@ -700,37 +695,36 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
     if (cols <= 32) {
         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask_bias<float, 1><<<grid, block>>>(
-                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
-                (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias<float, 1><<<grid, block>>>(
+                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
+                (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask_bias<at::Half, 1><<<grid, block>>>(
-                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias<at::Half, 1><<<grid, block>>>(
+                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask_bias<at::BFloat16, 1>
-                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                                  (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
-                                  rows, cols, scale, head);
+            fastfold_softmax_mask_bias<at::BFloat16, 1>
+                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                                  (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
+                                  rows, cols, head);
         }
     }
 #define COLS_CASE(col_per_thread)                                                               \
     else if (cols <= col_per_thread * 32) {                                                      \
         if (input.dtype() == torch::kFloat32) {                                                  \
-            fastfold_softmax_scale_mask_bias<float, col_per_thread><<<grid, block>>>(            \
-                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),   \
-                (float *)input.data_ptr(), rows, cols, scale, head);                             \
+            fastfold_softmax_mask_bias<float, col_per_thread><<<grid, block>>>(                  \
+                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),   \
+                (float *)input.data_ptr(), rows, cols, head);                                    \
         } else if (input.dtype() == torch::kFloat16) {                                           \
-            fastfold_softmax_scale_mask_bias<at::Half, col_per_thread>                           \
-                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),     \
-                                  (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows,\
-                                  cols, scale, head);                                            \
+            fastfold_softmax_mask_bias<at::Half, col_per_thread><<<grid, block>>>(               \
+                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),                       \
+                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);    \
         } else if (input.dtype() == torch::kBFloat16) {                                          \
-            fastfold_softmax_scale_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>(     \
-                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),               \
-                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,   \
-                scale, head);                                                                    \
+            fastfold_softmax_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>(           \
+                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),               \
+                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,   \
+                head);                                                                           \
         }                                                                                        \
     }
     COLS_CASE(2)
     COLS_CASE(3)
@@ -773,27 +767,26 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
         const size_t smem = cols * sizeof(float);

         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask_bias_sm<float, 128><<<grid, block, smem>>>(
-                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
-                (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias_sm<float, 128><<<grid, block, smem>>>(
+                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
+                (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
-                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
+                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
-                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,
-                scale, head);
+            fastfold_softmax_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
+                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,
+                head);
         }
     }

     return input;
 }

-at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output,
-                                                  at::Tensor mask, at::Tensor bias, long long rows,
-                                                  long long cols, float scale) {
+at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output,
+                                            at::Tensor mask, at::Tensor bias, long long rows,
+                                            long long cols) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
@@ -804,19 +797,18 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
-            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
-            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_mask_grad<float><<<grid, block>>>(
+            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
+            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kFloat16) {
-        fastfold_softmax_scale_mask_grad<at::Half>
-            <<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
-                              (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows,
-                              cols, scale, head);
+        fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
+            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
+            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kBFloat16) {
-        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
-            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
-            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
-            scale, head);
+        fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
+            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
+            head);
     }

     return grad_input;
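The backward kernels now compute the plain softmax gradient per row, (dy − Σ(dy·y))·y, with no trailing scale factor; any scaling is already baked into the logits that produced y. A quick CPU check of that formula against autograd, using plain PyTorch and illustrative shapes only:

import torch

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.randn_like(y)

# Same per-row expression the fused backward uses: (dy - sum(dy * y)) * y
manual = (dy - (dy * y).sum(dim=-1, keepdim=True)) * y
auto, = torch.autograd.grad(y, x, grad_outputs=dy)
print(torch.allclose(manual, auto))  # True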
fastfold/model/fastnn/kernel/cuda_native/softmax.py  (+14 −16)  [view file @ ded582b2]

@@ -31,18 +31,17 @@ class SoftmaxAffineFunction(torch.autograd.Function):
         return grad_input


-class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
+class FusedMaskSoftmaxFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, mask, scale):
+    def forward(ctx, input, mask):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_softmax_forward(
-            input_, mask_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_softmax_forward(
+            input_, mask_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_)
-        ctx.scale = scale

         return output
@@ -52,25 +51,24 @@ class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
         output, mask_ = ctx.saved_tensors
         grad_input = None

-        grad_input = fastfold_softmax_cuda.fused_scale_mask_softmax_backward(
-            grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(
+            grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)

         return grad_input.contiguous(), None, None


-class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
+class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, mask, bias, scale):
+    def forward(ctx, input, mask, bias):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         bias_ = bias.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_forward(
-            input_, mask_, bias_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(
+            input_, mask_, bias_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_, bias_)
-        ctx.scale = scale

         return output
@@ -80,8 +78,8 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
         output, mask_, bias_ = ctx.saved_tensors
         grad_input = None

-        grad_input = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_backward(
-            grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_bias_softmax_backward(
+            grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols)

         grad_input = grad_input.contiguous()
@@ -91,5 +89,5 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
 softmax = SoftmaxAffineFunction.apply
-scale_mask_softmax = FusedScaleMaskSoftmaxFunction.apply
-scale_mask_bias_softmax = FusedScaleMaskBiasSoftmaxFunction.apply
+mask_softmax = FusedMaskSoftmaxFunction.apply
+mask_bias_softmax = FusedMaskBiasSoftmaxFunction.apply
fastfold/model/fastnn/ops.py  (+4 −13)  [view file @ ded582b2]

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from fastfold.model.fastnn.kernel import scale_mask_softmax, scale_mask_bias_softmax
+from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
 from fastfold.model.fastnn.kernel import LayerNorm

 from .initializer import glorot_uniform_af
@@ -160,26 +160,17 @@ class SelfAttention(nn.Module):
         qkv = self.to_qkv(in_data).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

-        # q = self.to_q(in_data)
-        # k = self.to_k(in_data)
-        # v = self.to_k(in_data)
-        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), [q, k, v])
-
-        # q = q * self.scaling
+        q = q * self.scaling

         logits = torch.matmul(q, k.transpose(-1, -2))
-        # logits += mask

         if nonbatched_bias is not None:
             # logits += nonbatched_bias.unsqueeze(1)
             bias = gather_async_opp(*nonbatched_bias, dim=1)
             bias = rearrange(bias, 'b q k h -> b h q k')
-            weights = scale_mask_bias_softmax(logits, mask, bias.unsqueeze(1), self.scaling)
+            weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
         else:
-            weights = scale_mask_softmax(logits, mask, self.scaling)
-        # weights = torch.softmax(logits, dim=-1)
-        # weights = softmax(logits)
+            weights = mask_softmax(logits, mask)

         weighted_avg = torch.matmul(weights, v)
         weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
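Because SelfAttention now scales q before the matmul, the attention weights are unchanged. A minimal numerical check of the old and new orderings, with torch.softmax standing in for the fused masked kernel and purely illustrative shapes:

import torch


def masked_softmax_ref(logits, mask):
    # CPU stand-in for mask_softmax: masked positions are pushed to -1e9, as in the kernel.
    return torch.softmax(logits.masked_fill(mask == 0, -1e9), dim=-1)


torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)            # (batch, heads, n, d) -- illustrative only
k = torch.randn(2, 4, 8, 16)
mask = (torch.rand(2, 1, 8, 8) > 0.1)   # broadcastable boolean mask, illustrative only
scaling = 16 ** -0.5

# Old: the kernel scaled the raw logits.   New: q is scaled before the matmul.
old_weights = masked_softmax_ref(scaling * (q @ k.transpose(-1, -2)), mask)
new_weights = masked_softmax_ref((q * scaling) @ k.transpose(-1, -2), mask)

print(torch.allclose(old_weights, new_weights, atol=1e-6))  # True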