OpenDAS / FastFold · Commits

Commit ded582b2 (unverified)
Authored Jul 27, 2022 by shenggan; committed via GitHub on Jul 27, 2022

remove scale in fused softmax kernel (#34)

parent ad7f0cb5

Showing 5 changed files with 116 additions and 138 deletions (+116, -138)
Changed files:
    fastfold/model/fastnn/kernel/__init__.py                              +2   -2
    fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp        +12  -15
    fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu  +84  -92
    fastfold/model/fastnn/kernel/cuda_native/softmax.py                   +14  -16
    fastfold/model/fastnn/ops.py                                          +4   -13
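In short, this commit removes the scale argument from FastFold's fused softmax path: the CUDA kernels and their pybind bindings drop the float scale parameter, the autograd wrappers in softmax.py lose their scale input, and SelfAttention in ops.py now multiplies q by self.scaling before the q @ k^T matmul instead of passing the factor into the kernel. The two formulations are mathematically equivalent, because scaling q scales the logits by the same factor. A minimal PyTorch check of that identity (a sketch; shapes are illustrative and the -1e9 fill mirrors the kernels in this diff):

    import torch

    # softmax(mask_fill(scale * (q @ k^T))) == softmax(mask_fill((scale * q) @ k^T))
    q, k = torch.randn(4, 8, 16), torch.randn(4, 8, 16)
    mask = (torch.rand(4, 1, 8) > 0.1).float()
    scale = 16 ** -0.5

    old = torch.softmax(
        (torch.matmul(q, k.transpose(-1, -2)) * scale).masked_fill(mask == 0, -1e9), dim=-1)
    new = torch.softmax(
        torch.matmul(q * scale, k.transpose(-1, -2)).masked_fill(mask == 0, -1e9), dim=-1)
    print(torch.allclose(old, new, atol=1e-6))  # expect: True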
fastfold/model/fastnn/kernel/__init__.py

 from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
 from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
-from .cuda_native.softmax import softmax, scale_mask_softmax, scale_mask_bias_softmax
+from .cuda_native.softmax import softmax, mask_softmax, mask_bias_softmax

 __all__ = [
     "bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
-    "scale_mask_softmax", "scale_mask_bias_softmax"
+    "mask_softmax", "mask_bias_softmax"
 ]
\ No newline at end of file
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp

@@ -3,28 +3,25 @@
 at::Tensor softmax(at::Tensor input, long long rows, long long cols);
 at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);

-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols, float scale);
-at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, long long rows, long long cols, float scale);
+at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols);
+at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, long long rows, long long cols);

-at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols, float scale);
-at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols, float scale);
+at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols);
+at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("forward", &softmax, "Softmax forward (CUDA)");
     m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
-    m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward, "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward, "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward, "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward, "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_forward", &fused_mask_softmax_forward, "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_backward", &fused_mask_softmax_backward, "Softmax forward (CUDA)");
+    m.def("fused_mask_bias_softmax_forward", &fused_mask_bias_softmax_forward, "Softmax forward (CUDA)");
+    m.def("fused_mask_bias_softmax_backward", &fused_mask_bias_softmax_backward, "Softmax forward (CUDA)");
 }
\ No newline at end of file
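After rebuilding the extension, the Python-visible entry points change with the bindings above. A small sanity-check sketch, assuming the extension built from softmax_cuda.cpp / softmax_cuda_kernel.cu is importable as fastfold_softmax_cuda (the name used by softmax.py below):

    # The rebuilt extension should expose the renamed bindings and no longer
    # expose the old fused_scale_* ones.
    import fastfold_softmax_cuda as ext

    new_names = [
        "fused_mask_softmax_forward",
        "fused_mask_softmax_backward",
        "fused_mask_bias_softmax_forward",
        "fused_mask_bias_softmax_backward",
    ]
    assert all(hasattr(ext, name) for name in new_names)
    assert not hasattr(ext, "fused_scale_mask_softmax_forward")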
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu

@@ -330,8 +330,8 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
 ////////////////

 template <typename T, int cols_per_thread>
-__global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long long rows, long long cols, float scale, int head) {
+__global__ void fastfold_softmax_mask(T *input, T *mask, T *output, long long rows, long long cols, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;

@@ -349,7 +349,7 @@ __global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long l
         if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
             buf[i] = -1 * 1e9;
         } else {
-            buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
+            buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]);
         }
     }

@@ -376,8 +376,8 @@ __global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long l
 }

 template <typename T, int block_size>
-__global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, long long rows, long long cols, float scale, int head) {
+__global__ void fastfold_softmax_mask_sm(T *input, T *mask, T *output, long long rows, long long cols, int head) {
     extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
     auto *buf = reinterpret_cast<float *>(shared_buf);
     const int tid = threadIdx.x;

@@ -389,7 +389,7 @@ __global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, lon
             if (mask_ptr[id] == 0) {
                 buf[id] = -1 * 1e9;
             } else {
-                buf[id] = input[row * cols + id] * scale;
+                buf[id] = input[row * cols + id];
             }
             thread_max = max(thread_max, buf[id]);
         }

@@ -410,8 +410,8 @@ __global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, lon
     }
 }

-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols, float scale) {
+at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

@@ -423,33 +423,33 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
     if (cols <= 32) {
         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask<float, 1>
-                <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
-                                  (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask<float, 1>
+                <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
+                                  (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask<at::Half, 1>
-                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                                  (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask<at::Half, 1>
+                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                                  (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask<at::BFloat16, 1>
-                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                                  (at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask<at::BFloat16, 1>
+                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                                  (at::BFloat16 *)input.data_ptr(), rows, cols, head);
         }
     }
 #define COLS_CASE(col_per_thread)                                                             \
     else if (cols <= col_per_thread * 32) {                                                   \
         if (input.dtype() == torch::kFloat32) {                                               \
-            fastfold_softmax_scale_mask<float, col_per_thread>                                \
+            fastfold_softmax_mask<float, col_per_thread>                                      \
                 <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),        \
-                                  (float *)input.data_ptr(), rows, cols, scale, head);        \
+                                  (float *)input.data_ptr(), rows, cols, head);               \
         } else if (input.dtype() == torch::kFloat16) {                                        \
-            fastfold_softmax_scale_mask<at::Half, col_per_thread>                             \
+            fastfold_softmax_mask<at::Half, col_per_thread>                                   \
                 <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),  \
-                                  (at::Half *)input.data_ptr(), rows, cols, scale, head);     \
+                                  (at::Half *)input.data_ptr(), rows, cols, head);            \
         } else if (input.dtype() == torch::kBFloat16) {                                       \
-            fastfold_softmax_scale_mask<at::BFloat16, col_per_thread><<<grid, block>>>(       \
+            fastfold_softmax_mask<at::BFloat16, col_per_thread><<<grid, block>>>(             \
                 (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),            \
-                (at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);                   \
+                (at::BFloat16 *)input.data_ptr(), rows, cols, head);                          \
         }                                                                                     \
     }
     COLS_CASE(2)

@@ -493,26 +493,25 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
         const size_t smem = cols * sizeof(float);
         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask_sm<float, 128>
-                <<<grid, block, smem>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
-                                        (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_sm<float, 128>
+                <<<grid, block, smem>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
+                                        (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask_sm<at::Half, 128>
-                <<<grid, block, smem>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                                        (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_sm<at::Half, 128>
+                <<<grid, block, smem>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                                        (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask_sm<at::BFloat16, 128><<<grid, block, smem>>>(
-                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                (at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_sm<at::BFloat16, 128><<<grid, block, smem>>>(
+                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                (at::BFloat16 *)input.data_ptr(), rows, cols, head);
         }
     }

     return input;
 }

 template <typename T>
-__global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_input, T *mask, long long rows, long long cols, float scale, int head) {
+__global__ void fastfold_softmax_mask_grad(T *d_output, T *output, T *d_input, T *mask, long long rows, long long cols, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;

@@ -559,7 +558,7 @@ __global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_in
     for (int i = 0; i < cols_this_thread; ++i) {
         if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
             row_d_input[lane_id * cols_per_thread + i] =
-                static_cast<T>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
+                static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
         } else {
             row_d_input[lane_id * cols_per_thread + i] = 0;
         }

@@ -567,9 +566,8 @@ __global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_in
     }
 }

-at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask, long long rows, long long cols, float scale) {
+at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask, long long rows, long long cols) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));

@@ -580,19 +578,18 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
-            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
-            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_mask_grad<float><<<grid, block>>>(
+            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
+            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kFloat16) {
-        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>(
-            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
-            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
+            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
+            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kBFloat16) {
-        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
-            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
-            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-            rows, cols, scale, head);
+        fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
+            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+            rows, cols, head);
     }

     return grad_input;

@@ -601,9 +598,8 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
 ////////////////

 template <typename T, int cols_per_thread>
-__global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *output, long long rows, long long cols, float scale, int head) {
+__global__ void fastfold_softmax_mask_bias(T *input, T *mask, T *bias, T *output, long long rows, long long cols, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;

@@ -622,8 +618,8 @@ __global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *
         if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
             buf[i] = -1 * 10e9;
         } else {
-            buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
-            buf[i] += static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
+            buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) +
+                     static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
         }
     }

@@ -650,9 +646,8 @@ __global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *
 }

 template <typename T, int block_size>
-__global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias, T *output, long long rows, long long cols, float scale, int head) {
+__global__ void fastfold_softmax_mask_bias_sm(T *input, T *mask, T *bias, T *output, long long rows, long long cols, int head) {
     extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
     auto *buf = reinterpret_cast<float *>(shared_buf);
     const int tid = threadIdx.x;

@@ -665,7 +660,7 @@ __global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias,
             if (mask_ptr[id] == 0) {
                 buf[id] = -1 * 1e9;
             } else {
-                buf[id] = input[row * cols + id] * scale + bias_ptr[id];
+                buf[id] = input[row * cols + id] + bias_ptr[id];
             }
             thread_max = max(thread_max, buf[id]);
         }

@@ -686,8 +681,8 @@ __global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias,
     }
 }

-at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols, float scale) {
+at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias, long long rows, long long cols) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     CHECK_INPUT(bias);

@@ -700,37 +695,36 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
     if (cols <= 32) {
         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask_bias<float, 1><<<grid, block>>>(
-                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
-                (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias<float, 1><<<grid, block>>>(
+                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
+                (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask_bias<at::Half, 1><<<grid, block>>>(
-                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias<at::Half, 1><<<grid, block>>>(
+                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask_bias<at::BFloat16, 1>
-                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                                  (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
-                                  rows, cols, scale, head);
+            fastfold_softmax_mask_bias<at::BFloat16, 1>
+                <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                                  (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
+                                  rows, cols, head);
         }
     }
 #define COLS_CASE(col_per_thread)                                                                  \
     else if (cols <= col_per_thread * 32) {                                                        \
         if (input.dtype() == torch::kFloat32) {                                                    \
-            fastfold_softmax_scale_mask_bias<float, col_per_thread><<<grid, block>>>(              \
+            fastfold_softmax_mask_bias<float, col_per_thread><<<grid, block>>>(                    \
                 (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),     \
-                (float *)input.data_ptr(), rows, cols, scale, head);                               \
+                (float *)input.data_ptr(), rows, cols, head);                                      \
         } else if (input.dtype() == torch::kFloat16) {                                             \
-            fastfold_softmax_scale_mask_bias<at::Half, col_per_thread>                             \
-                <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),       \
-                                  (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, \
-                                  cols, scale, head);                                              \
+            fastfold_softmax_mask_bias<at::Half, col_per_thread><<<grid, block>>>(                 \
+                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),                         \
+                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);      \
         } else if (input.dtype() == torch::kBFloat16) {                                            \
-            fastfold_softmax_scale_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>(       \
+            fastfold_softmax_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>(             \
                 (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),                 \
-                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,     \
-                scale, head);                                                                      \
+                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,     \
+                head);                                                                             \
         }                                                                                          \
     }
     COLS_CASE(2)
     COLS_CASE(3)

@@ -773,27 +767,26 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
         const size_t smem = cols * sizeof(float);
         if (input.dtype() == torch::kFloat32) {
-            fastfold_softmax_scale_mask_bias_sm<float, 128><<<grid, block, smem>>>(
-                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
-                (float *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias_sm<float, 128><<<grid, block, smem>>>(
+                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
+                (float *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kFloat16) {
-            fastfold_softmax_scale_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
-                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
-                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, scale, head);
+            fastfold_softmax_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
+                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
         } else if (input.dtype() == torch::kBFloat16) {
-            fastfold_softmax_scale_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
-                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
-                rows, cols, scale, head);
+            fastfold_softmax_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
+                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
+                rows, cols, head);
         }
     }

     return input;
 }

-at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask, at::Tensor bias, long long rows, long long cols, float scale) {
+at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask, at::Tensor bias, long long rows, long long cols) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));

@@ -804,19 +797,18 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
-            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
-            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_mask_grad<float><<<grid, block>>>(
+            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
+            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kFloat16) {
-        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>(
-            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
-            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
+            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
+            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
     } else if (output.dtype() == torch::kBFloat16) {
-        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
-            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
-            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-            rows, cols, scale, head);
+        fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
+            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+            rows, cols, head);
     }

     return grad_input;
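For readers who do not want to trace the CUDA, the rewritten kernels now compute a plain masked softmax and its standard backward, with no scale factor anywhere. A pure-PyTorch reference of the same math, runnable on CPU (a readability aid, not code from the repository; the -1e9 constant and the zeroed masked gradient mirror the kernels above, shapes are illustrative):

    import torch

    def mask_softmax_ref(x, mask):
        # Mirrors the rewritten forward kernels: masked positions are pinned to
        # -1e9 before a row-wise softmax; no scale is applied.
        x = torch.where(mask != 0, x, torch.full_like(x, -1e9))
        return torch.softmax(x, dim=-1)

    def mask_softmax_grad_ref(dy, y, mask):
        # Mirrors fastfold_softmax_mask_grad: standard softmax backward
        # (dy - sum(dy * y)) * y with masked columns zeroed; the old `scale *`
        # factor is gone because scaling now happens before the kernel is called.
        row_sum = (dy * y).sum(dim=-1, keepdim=True)
        dx = (dy - row_sum) * y
        return torch.where(mask != 0, dx, torch.zeros_like(dx))

    # Cross-check against autograd.
    x = torch.randn(4, 16, requires_grad=True)
    mask = (torch.rand(4, 16) > 0.2).float()
    y = mask_softmax_ref(x, mask)
    dy = torch.randn_like(y)
    y.backward(dy)
    print(torch.allclose(x.grad, mask_softmax_grad_ref(dy, y.detach(), mask), atol=1e-5))  # expect: True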
fastfold/model/fastnn/kernel/cuda_native/softmax.py

@@ -31,18 +31,17 @@ class SoftmaxAffineFunction(torch.autograd.Function):
         return grad_input


-class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
+class FusedMaskSoftmaxFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, mask, scale):
+    def forward(ctx, input, mask):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_softmax_forward(input_, mask_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_softmax_forward(input_, mask_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_)
-        ctx.scale = scale

         return output

@@ -52,25 +51,24 @@ class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
         output, mask_ = ctx.saved_tensors
         grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_scale_mask_softmax_backward(grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)

         return grad_input.contiguous(), None, None


-class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
+class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, mask, bias, scale):
+    def forward(ctx, input, mask, bias):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         bias_ = bias.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_forward(input_, mask_, bias_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(input_, mask_, bias_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_, bias_)
-        ctx.scale = scale

         return output

@@ -80,8 +78,8 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
         output, mask_, bias_ = ctx.saved_tensors
         grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_backward(grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_bias_softmax_backward(grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols)
         grad_input = grad_input.contiguous()

@@ -91,5 +89,5 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
 softmax = SoftmaxAffineFunction.apply
-scale_mask_softmax = FusedScaleMaskSoftmaxFunction.apply
-scale_mask_bias_softmax = FusedScaleMaskBiasSoftmaxFunction.apply
+mask_softmax = FusedMaskSoftmaxFunction.apply
+mask_bias_softmax = FusedMaskBiasSoftmaxFunction.apply
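Because the wrappers no longer accept a scale, any downstream caller that still has one must fold it into the logits itself. A hypothetical compatibility shim, not part of this commit, assuming the fastfold package and its CUDA extension are installed:

    # Recover the old behaviour by scaling the logits before the new wrappers.
    from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax

    def scale_mask_softmax(logits, mask, scale):
        # old fused kernel: softmax of (logits * scale) with masked columns at -1e9
        return mask_softmax(logits * scale, mask)

    def scale_mask_bias_softmax(logits, mask, bias, scale):
        # old fused kernel: softmax of (logits * scale + bias) with masked columns at -1e9
        return mask_bias_softmax(logits * scale, mask, bias)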
fastfold/model/fastnn/ops.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from fastfold.model.fastnn.kernel import scale_mask_softmax, scale_mask_bias_softmax
+from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
 from fastfold.model.fastnn.kernel import LayerNorm
 from .initializer import glorot_uniform_af

@@ -160,26 +160,17 @@ class SelfAttention(nn.Module):
         qkv = self.to_qkv(in_data).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

-        # q = self.to_q(in_data)
-        # k = self.to_k(in_data)
-        # v = self.to_k(in_data)
-        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), [q, k, v])
-        # q = q * self.scaling
+        q = q * self.scaling

         logits = torch.matmul(q, k.transpose(-1, -2))
-        # logits += mask

         if nonbatched_bias is not None:
-            # logits += nonbatched_bias.unsqueeze(1)
             bias = gather_async_opp(*nonbatched_bias, dim=1)
             bias = rearrange(bias, 'b q k h -> b h q k')
-            weights = scale_mask_bias_softmax(logits, mask, bias.unsqueeze(1), self.scaling)
+            weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
         else:
-            weights = scale_mask_softmax(logits, mask, self.scaling)
-            # weights = torch.softmax(logits, dim=-1)
-            # weights = softmax(logits)
+            weights = mask_softmax(logits, mask)

         weighted_avg = torch.matmul(weights, v)
         weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')