FastFold · Commit 16d10d6a (unverified)

Authored Mar 09, 2022 by shenggan; committed by GitHub on Mar 09, 2022
Parents: 90019096 and 8e75ab95

Merge pull request #6 from hpcaitech/inject_openfold

add inject_openfold
Showing 10 changed files with 938 additions and 347 deletions:

  fastfold/distributed/core.py                                        +4    -0
  fastfold/model/kernel/cuda_native/csrc/layer_norm_cuda.cpp          +5    -0
  fastfold/model/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu    +96   -70
  fastfold/model/kernel/cuda_native/csrc/softmax_cuda_kernel.cu       +387  -258
  fastfold/model/kernel/jit/fused_ops.py                              +5    -5
  fastfold/model/msa.py                                               +1    -1
  fastfold/model/triangle.py                                          +32   -13
  fastfold/utils/__init__.py                                          +3    -0
  fastfold/utils/inject_openfold.py                                   +182  -0
  inference.py                                                        +223  -0
fastfold/distributed/core.py

@@ -74,6 +74,8 @@ def get_data_parallel_group():

 def get_tensor_model_parallel_world_size():
+    if not dap_is_initialized():
+        return 1
     """Return world size for the tensor model parallel group."""
     global _TENSOR_MODEL_PARALLEL_WORLD_SIZE
     if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:

@@ -82,6 +84,8 @@ def get_tensor_model_parallel_world_size():

 def get_tensor_model_parallel_rank():
+    if not dap_is_initialized():
+        return 0
     """Return my rank for the tensor model parallel group."""
     global _TENSOR_MODEL_PARALLEL_RANK
     if _TENSOR_MODEL_PARALLEL_RANK is not None:
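The two added guards let the tensor-model-parallel getters degrade gracefully when DAP (dynamic axial parallelism) has not been initialized. A minimal sketch of the resulting single-process behaviour, assuming fastfold.distributed.core is importable and no DAP setup has run:

# Illustrative only: with no DAP initialization, the patched getters fall back
# to a single-worker view instead of assuming a process group exists.
from fastfold.distributed.core import (get_tensor_model_parallel_rank,
                                       get_tensor_model_parallel_world_size)

assert get_tensor_model_parallel_world_size() == 1  # guard added in this hunk
assert get_tensor_model_parallel_rank() == 0        # guard added in this hunk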
fastfold/model/kernel/cuda_native/csrc/layer_norm_cuda.cpp

 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>

 #include <cassert>
 #include <vector>

@@ -74,6 +75,8 @@ std::vector<at::Tensor> layer_norm_affine(at::Tensor input, at::IntArrayRef norm
     int n1, n2;
     check_args(input, normalized_shape, gamma, beta, n1, n2);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
     at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
     at::Tensor mean = at::empty({n1}, input.options().dtype(at::ScalarType::Float));
     at::Tensor invvar = at::empty_like(mean);

@@ -104,6 +107,8 @@ std::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor dout, at::Tensor m
     int n1, n2;
     check_args(input, normalized_shape, gamma, beta, n1, n2);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
     at::Tensor grad_input = at::empty_like(input);
     at::Tensor grad_gamma = at::empty_like(gamma);
     at::Tensor grad_beta = at::empty_like(beta);
fastfold/model/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu

@@ -64,15 +64,27 @@ __global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamm
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float buf[32];

-        float thread_mean;
-        float thread_m2;
-        float thread_count;
+        float thread_mean = 0.f;
+        float thread_m2 = 0.f;
+        float thread_count = 0.f;

         float warp_mean;
         float warp_m2;

@@ -81,13 +93,13 @@ __global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamm
         float* row_input = input + row_offset * cols;
         float* row_output = output + row_offset * cols;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = row_input[lane_id * cols_per_thread + i];
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
         }

@@ -102,16 +114,17 @@ __global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamm
             invvar[row_offset] = row_inv_var;
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = (buf[i] - row_mean) * row_inv_var;
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
                 buf[i] * gamma[lane_id * cols_per_thread + i] + beta[lane_id * cols_per_thread + i];
         }
     }
 }

 __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* output,

@@ -120,15 +133,27 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float buf[32];

-        float thread_mean;
-        float thread_m2;
-        float thread_count;
+        float thread_mean = 0.f;
+        float thread_m2 = 0.f;
+        float thread_count = 0.f;

         float warp_mean;
         float warp_m2;

@@ -137,13 +162,13 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
         at::BFloat16* row_input = input + row_offset * cols;
         at::BFloat16* row_output = output + row_offset * cols;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
         }

@@ -158,23 +183,24 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
             invvar[row_offset] = row_inv_var;
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = (buf[i] - row_mean) * row_inv_var;
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
                 static_cast<at::BFloat16>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
                 beta[lane_id * cols_per_thread + i];
         }
     }
 }

 void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
                      int rows, int cols, at::IntArrayRef normalized_shape, at::Tensor* gamma,
                      at::Tensor* beta, double epsilon) {
-    int grid = rows / 4;
+    int grid = (rows + 3) / 4;
     dim3 block(128);

     if (output->dtype() == torch::kFloat32) {
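The recurring change in these kernels is the switch from cols / 32 to ceiling division plus an explicit per-lane column count, so rows whose width is not a multiple of 32 are still fully covered by one warp. A quick Python check of that arithmetic (it mirrors the kernel's integer math; the helper name is mine):

# Each of the 32 lanes in a warp handles `lane_cols(cols, y)` columns; the
# counts must sum to `cols` whether or not `cols` is a multiple of 32.
def lane_cols(cols: int, threadidx_y: int) -> int:
    cols_per_thread = (cols + 31) // 32        # ceiling division, as in the new code
    cols_this_thread = cols_per_thread
    last_y = cols // cols_per_thread
    if threadidx_y == last_y:
        cols_this_thread = cols - cols_per_thread * last_y   # partial tail lane
    elif threadidx_y > last_y:
        cols_this_thread = 0                                  # idle lanes
    return cols_this_thread

for cols in (64, 70, 100, 127, 128):
    assert sum(lane_cols(cols, y) for y in range(32)) == cols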
fastfold/model/kernel/cuda_native/csrc/softmax_cuda_kernel.cu

 #include <math_constants.h>
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>

 #include <iostream>

@@ -33,41 +34,53 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, int rows, int
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float *row_input = input + row_offset * cols;
         float *row_output = output + row_offset * cols;
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = row_input[lane_id * cols_per_thread + i];
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
         }
     }
 }

 __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, int rows,

@@ -75,42 +88,55 @@ __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         at::BFloat16 *row_input = input + row_offset * cols;
         at::BFloat16 *row_output = output + row_offset * cols;
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
         }

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
                 static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
         }
     }
 }

 __global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, int rows,

@@ -118,37 +144,49 @@ __global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float y_buf[32];
     float dy_buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float *row_d_output = d_output + row_offset * cols;
         float *row_output = output + row_offset * cols;
         float *row_d_input = d_input + row_offset * cols;
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             y_buf[i] = row_output[lane_id * cols_per_thread + i];
             dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
         }

         float thread_sum = 0.f;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
-            row_d_input[lane_id * cols_per_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
+        for (int i = 0; i < cols_this_thread; ++i) {
+            row_d_input[lane_id * cols_this_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
         }
     }
 }

@@ -157,46 +195,60 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float y_buf[32];
     float dy_buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         at::BFloat16 *row_d_output = d_output + row_offset * cols;
         at::BFloat16 *row_output = output + row_offset * cols;
         at::BFloat16 *row_d_input = d_input + row_offset * cols;
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
             dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
         }

         float thread_sum = 0.f;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_d_input[lane_id * cols_per_thread + i] =
                 static_cast<at::BFloat16>((dy_buf[i] - warp_sum) * y_buf[i]);
         }
     }
 }

 at::Tensor softmax(at::Tensor input, int rows, int cols) {
     CHECK_INPUT(input);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     at::Tensor output = at::empty_like(input);

-    int grid = rows / 4;
+    int grid = (rows + 3) / 4;
     dim3 block(128);

     if (input.dtype() == torch::kFloat32) {

@@ -212,9 +264,10 @@ at::Tensor softmax(at::Tensor input, int rows, int cols) {
 at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols) {
     CHECK_INPUT(output);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
     at::Tensor grad_input = at::empty_like(output);

-    int grid = rows / 4;
+    int grid = (rows + 3) / 4;
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {

@@ -237,18 +290,29 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float *row_input = input + row_offset * cols;
         float *row_output = output + row_offset * cols;
         float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * CUDART_INF_F;
             } else {

@@ -257,25 +321,26 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
         }

         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
         }
     }
 }

 __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask,

@@ -284,18 +349,29 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         at::BFloat16 *row_input = input + row_offset * cols;
         at::BFloat16 *row_output = output + row_offset * cols;
         at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * CUDART_INF_F;
             } else {

@@ -304,36 +380,38 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
         }

         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
                 static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
         }
     }
 }

 at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
                                             float scale) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     int head = input.sizes()[2];
     at::Tensor output = at::empty_like(input);

-    int grid = rows / 4;
+    int grid = (rows + 3) / 4;
     dim3 block(128);

     if (input.dtype() == torch::kFloat32) {

@@ -355,13 +433,24 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float y_buf[32];
     float dy_buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float *row_d_output = d_output + row_offset * cols;
         float *row_output = output + row_offset * cols;
         float *row_d_input = d_input + row_offset * cols;

@@ -369,23 +458,23 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             y_buf[i] = row_output[lane_id * cols_per_thread + i];
             dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
         }

         float thread_sum = 0.f;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
                 row_d_input[lane_id * cols_per_thread + i] =
                     scale * ((dy_buf[i] - warp_sum) * y_buf[i]);

@@ -393,6 +482,7 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
                 row_d_input = 0;
             }
         }
     }
 }

 __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,

@@ -401,13 +491,24 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float y_buf[32];
     float dy_buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         at::BFloat16 *row_d_output = d_output + row_offset * cols;
         at::BFloat16 *row_output = output + row_offset * cols;
         at::BFloat16 *row_d_input = d_input + row_offset * cols;

@@ -415,23 +516,23 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
             dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
         }

         float thread_sum = 0.f;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
                 row_d_input[lane_id * cols_per_thread + i] =
                     static_cast<at::BFloat16>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));

@@ -439,16 +540,18 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
                 row_d_input = 0;
             }
         }
     }
 }

 at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
                                              at::Tensor mask, int rows, int cols, float scale) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
     int head = output.sizes()[2];
     at::Tensor grad_input = at::empty_like(output);

-    int grid = rows / 4;
+    int grid = (rows + 3) / 4;
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {

@@ -473,19 +576,30 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float *row_input = input + row_offset * cols;
         float *row_output = output + row_offset * cols;
         float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
         float *bias_ptr = bias + ((row_offset % (head * cols)) * cols);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * CUDART_INF_F;
             } else {

@@ -495,25 +609,26 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
         }

         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
         }
     }
 }

 __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask,

@@ -522,19 +637,30 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = cols / 32;
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    } else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
     float buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         at::BFloat16 *row_input = input + row_offset * cols;
         at::BFloat16 *row_output = output + row_offset * cols;
         at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
         at::BFloat16 *bias_ptr = bias + ((row_offset % (head * cols)) * cols);

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * CUDART_INF_F;
             } else {

@@ -544,26 +670,27 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
         }

         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_per_thread; i++) {
+        for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
-        for (int i = 0; i < cols_per_thread; ++i) {
+        for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
                 static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
         }
     }
 }

 at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,

@@ -571,10 +698,11 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     CHECK_INPUT(bias);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
     int head = input.sizes()[2];
     at::Tensor output = at::empty_like(input);

-    int grid = rows / 4;
+    int grid = (rows + 3) / 4;
     dim3 block(128);

     if (input.dtype() == torch::kFloat32) {

@@ -596,10 +724,11 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
                                               int cols, float scale) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
     int head = output.sizes()[2];
     at::Tensor grad_input = at::empty_like(output);

-    int grid = rows / 4;
+    int grid = (rows + 3) / 4;
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
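For orientation, the forward path of these fused kernels is an ordinary masked, scaled softmax per row: masked positions are forced to negative infinity, the row maximum is subtracted, and the exponentials are normalized by the row sum. The sketch below restates that in plain PyTorch. The exact placement of scale and bias sits in branches elided from the hunks above, and the real kernels index mask and bias per attention head, so treat the helper as an assumption-level reference, not a drop-in equivalent of the CUDA code.

import torch

def scale_mask_bias_softmax_ref(x, mask, bias, scale):
    # mask == 0 positions are forced to -inf before the row-wise softmax,
    # matching the `buf[i] = -1 * CUDART_INF_F` branch in the kernels above.
    logits = torch.where(mask.bool(), scale * x + bias, torch.full_like(x, float("-inf")))
    logits = logits - logits.max(dim=-1, keepdim=True).values   # warp max
    probs = logits.exp()
    return probs / probs.sum(dim=-1, keepdim=True)               # warp sum

x = torch.randn(2, 8)
mask = torch.tensor([[1., 1., 0., 1., 1., 1., 0., 1.],
                     [1., 0., 1., 1., 1., 1., 1., 1.]])
out = scale_mask_bias_softmax_ref(x, mask, bias=torch.zeros(8), scale=1.0)
assert torch.allclose(out.sum(-1), torch.ones(2))   # each row still sums to 1
assert (out[mask == 0] == 0).all()                  # masked positions get zero weight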
fastfold/model/kernel/jit/fused_ops.py

@@ -9,14 +9,14 @@ def bias_sigmod_ele(y, bias, z):

 @torch.jit.script
 def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
-                     residual: torch.Tensor, prob: float) -> torch.Tensor:
-    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
+                     residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+    out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
     out = residual + out
     return out


 @torch.jit.script
 def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
-                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
-                              prob: float) -> torch.Tensor:
-    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
+                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor, prob: float,
+                              training: bool) -> torch.Tensor:
+    return Z_raw + F.dropout(dropout_mask, p=prob, training=training) * (g * (ab + b))
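Because these are free functions compiled with torch.jit.script, they cannot read self.training themselves; the caller now passes the module's mode explicitly, which is what makes dropout a no-op at inference time instead of the previous hard-coded training=True. A self-contained toy check of the new signature (shapes and values here are illustrative, not FastFold's):

import torch
import torch.nn.functional as F

@torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
    return residual + out

x = torch.randn(2, 4)
bias = torch.zeros(4)
dropmask = torch.ones(2, 4)
residual = torch.zeros(2, 4)

# With training=False the dropout is skipped, so eval-mode output is deterministic.
eval_out = bias_dropout_add(x, bias, dropmask, residual, prob=0.15, training=False)
assert torch.allclose(eval_out, x)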
fastfold/model/msa.py

@@ -50,7 +50,7 @@ class MSARowAttentionWithPairBias(nn.Module):
         M = self.attention(M, M_mask, (b, work))

         dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
-        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
+        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)


 class MSAColumnAttention(nn.Module):
fastfold/model/triangle.py

@@ -65,7 +65,8 @@ class TriangleMultiplicationOutgoing(nn.Module):
                                          g, dropout_mask, Z_raw,
-                                         prob=self.p_drop)
+                                         prob=self.p_drop,
+                                         training=self.training)


 class TriangleMultiplicationIncoming(nn.Module):

@@ -103,10 +104,7 @@ class TriangleMultiplicationIncoming(nn.Module):
         left_proj_act = gather_async_opp(left_proj_act, work, dim=2)

-        p = torch.matmul(permute_final_dims(left_proj_act, (2, 1, 0)),
-                         right_proj_act)
+        p = torch.matmul(permute_final_dims(left_proj_act, (2, 1, 0)), right_proj_act)
         ab = permute_final_dims(p, (1, 2, 0))

         # ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)

@@ -117,7 +115,8 @@ class TriangleMultiplicationIncoming(nn.Module):
                                          g, dropout_mask, Z_raw,
-                                         prob=self.p_drop)
+                                         prob=self.p_drop,
+                                         training=self.training)


 class TriangleAttentionStartingNode(nn.Module):

@@ -156,7 +155,12 @@ class TriangleAttentionStartingNode(nn.Module):
         Z = self.attention(Z, Z_mask, (b, work))

         dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
-        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
+        return bias_dropout_add(Z,
+                                self.out_bias,
+                                dropout_mask,
+                                Z_raw,
+                                prob=self.p_drop,
+                                training=self.training)


 class TriangleAttentionEndingNode(nn.Module):

@@ -197,7 +201,12 @@ class TriangleAttentionEndingNode(nn.Module):
         Z = Z.transpose(-2, -3)

         dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
-        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
+        return bias_dropout_add(Z,
+                                self.out_bias,
+                                dropout_mask,
+                                Z_raw,
+                                prob=self.p_drop,
+                                training=self.training)


 class PairStack(nn.Module):

@@ -209,10 +218,20 @@ class PairStack(nn.Module):
         self.n_head = 4
         self.hidden_c = int(d_pair / self.n_head)

-        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop, c=d_pair)
-        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop, c=d_pair)
-        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop, c=self.hidden_c, n_head=self.n_head)
-        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop, c=self.hidden_c, n_head=self.n_head)
+        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair,
+                                                                             p_drop=p_drop,
+                                                                             c=d_pair)
+        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair,
+                                                                             p_drop=p_drop,
+                                                                             c=d_pair)
+        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair,
+                                                                           p_drop=p_drop,
+                                                                           c=self.hidden_c,
+                                                                           n_head=self.n_head)
+        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair,
+                                                                       p_drop=p_drop,
+                                                                       c=self.hidden_c,
+                                                                       n_head=self.n_head)

         self.PairTransition = Transition(d=d_pair)

     def forward(self, pair, pair_mask):
fastfold/utils/__init__.py (new file)

from .inject_openfold import inject_openfold

__all__ = ['inject_openfold']
(no newline at end of file)
fastfold/utils/inject_openfold.py (new file)

from typing import Tuple, Optional

import torch
import torch.nn as nn

from fastfold.model import MSAStack, OutProductMean, PairStack
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.distributed.comm import gather, scatter


class EvoformerBlock(nn.Module):

    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
        super(EvoformerBlock, self).__init__()

        self.first_block = first_block
        self.last_block = last_block

        self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
        self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=c_z)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        if self.first_block:
            m = m.unsqueeze(0)
            z = z.unsqueeze(0)

            m = scatter(m, dim=1)
            z = scatter(z, dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)

        m = self.msa_stack(m, z, msa_mask)
        z = z + self.communication(m, msa_mask)
        m, work = All_to_All_Async.apply(m, 1, 2)
        z = self.pair_stack(z, pair_mask)
        m = All_to_All_Async_Opp.apply(m, work, 1, 2)

        if self.last_block:
            m = m.squeeze(0)
            z = z.squeeze(0)

            m = gather(m, dim=0)
            z = gather(z, dim=0)

        return m, z


def copy_layernorm(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    model_fast.bias.copy_(model_ori.bias)


def copy_linear(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    if model_fast.use_bias:
        model_fast.bias.copy_(model_ori.bias)


def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
    model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))


def copy_attention(model_fast, model_ori):
    copy_qkv_linear(model_fast.to_qkv, model_ori.linear_q, model_ori.linear_k, model_ori.linear_v)
    copy_linear(model_fast.gating_linear, model_ori.linear_g)
    copy_linear(model_fast.o_linear, model_ori.linear_o)
    try:
        model_fast.gating_bias.copy_(model_ori.linear_g.bias)
    except:
        print("no gating_bias need copy")


def copy_left_right(model_fast, ori_left, ori_right):
    model_fast.weight.copy_(torch.cat((ori_left.weight, ori_right.weight), dim=0))
    model_fast.bias.copy_(torch.cat((ori_left.bias, ori_right.bias), dim=0))


def copy_transition(model_fast, model_ori):
    copy_layernorm(model_fast.norm, model_ori.layer_norm)
    copy_linear(model_fast.linear1, model_ori.linear_1)
    copy_linear(model_fast.linear2, model_ori.linear_2)


def copy_triangle(model_fast, model_ori):
    copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
    copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
    copy_linear(model_fast.output_gate, model_ori.linear_g)
    copy_linear(model_fast.output_projection, model_ori.linear_z)
    model_fast.output_bias.copy_(model_ori.linear_z.bias)

    copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p)
    copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g)


def copy_triangle_att(model_fast, model_ori):
    copy_layernorm(model_fast.layernorm1, model_ori.layer_norm)
    copy_linear(model_fast.linear_b, model_ori.linear)
    copy_attention(model_fast.attention, model_ori.mha)
    model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)


def copy_para(block_fast, block_ori):
    # msa_stack
    # MSARowAttentionWithPairBias
    copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
                   block_ori.msa_att_row.layer_norm_m)
    copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
                   block_ori.msa_att_row.layer_norm_z)
    copy_attention(block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
                   block_ori.msa_att_row.mha)
    block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
        block_ori.msa_att_row.linear_z.weight)
    block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
        block_ori.msa_att_row.mha.linear_o.bias)

    # MSAColumnAttention
    copy_layernorm(block_fast.msa_stack.MSAColumnAttention.layernormM,
                   block_ori.msa_att_col._msa_att.layer_norm_m)
    copy_attention(block_fast.msa_stack.MSAColumnAttention.attention,
                   block_ori.msa_att_col._msa_att.mha)

    # MSATransition
    copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)

    # communication
    copy_layernorm(block_fast.communication.layernormM, block_ori.core.outer_product_mean.layer_norm)
    copy_linear(block_fast.communication.linear_a, block_ori.core.outer_product_mean.linear_1)
    copy_linear(block_fast.communication.linear_b, block_ori.core.outer_product_mean.linear_2)
    copy_linear(block_fast.communication.o_linear, block_ori.core.outer_product_mean.linear_out)

    # pair_stack
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(block_fast.pair_stack.TriangleAttentionStartingNode, block_ori.core.tri_att_start)
    copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end)

    copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)


def inject_openfold(model):
    with torch.no_grad():
        fastfold_blocks = nn.ModuleList()
        for block_id, openfold_block in enumerate(model.evoformer.blocks):
            c_m = openfold_block.msa_att_row.c_in
            c_z = openfold_block.msa_att_row.c_z
            fastfold_block = EvoformerBlock(c_m=c_m,
                                            c_z=c_z,
                                            first_block=(block_id == 0),
                                            last_block=(block_id == len(model.evoformer.blocks) - 1))
            copy_para(fastfold_block, openfold_block)
            fastfold_blocks.append(fastfold_block)

        model.evoformer.blocks = fastfold_blocks

    return model
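For context, this is how the module is consumed elsewhere in the same commit (see inference.py below): build the stock OpenFold model, load its weights, then swap the Evoformer blocks in place. The parameter path and device shown here are only an example.

from fastfold.utils import inject_openfold
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_

config = model_config("model_1")
model = AlphaFold(config)
import_jax_weights_(model, "openfold/resources/params/params_model_1.npz", version="model_1")

# Replaces model.evoformer.blocks with FastFold EvoformerBlocks, copying the weights over.
model = inject_openfold(model)
model = model.eval().to("cuda:0")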
inference.py (new file)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import random
import sys
import time
from datetime import date

import numpy as np
import torch

import openfold.np.relax.relax as relax
from fastfold.utils import inject_openfold
from openfold.config import model_config
from openfold.data import data_pipeline, feature_pipeline, templates
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import protein, residue_constants
from openfold.utils.import_weights import import_jax_weights_
from openfold.utils.tensor_utils import tensor_tree_map
from scripts.utils import add_data_args


def main(args):
    config = model_config(args.model_name)
    model = AlphaFold(config)
    import_jax_weights_(model, args.param_path, version=args.model_name)

    model = inject_openfold(model)

    model = model.eval()
    #script_preset_(model)
    model = model.to(args.model_device)

    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=config.data.predict.max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path)

    use_small_bfd = (args.bfd_database_path is None)

    data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
    feature_processor = feature_pipeline.FeaturePipeline(config.data)
    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
    if (args.use_precomputed_alignments is None):
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    # Gather input sequences
    with open(args.fasta_path, "r") as fp:
        lines = [l.strip() for l in fp.readlines()]
    tags, seqs = lines[::2], lines[1::2]
    tags = [l[1:] for l in tags]

    for tag, seq in zip(tags, seqs):
        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
        with open(fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        print("Generating features...")
        local_alignment_dir = os.path.join(alignment_dir, tag)
        if (args.use_precomputed_alignments is None):
            if not os.path.exists(local_alignment_dir):
                os.makedirs(local_alignment_dir)

            alignment_runner = data_pipeline.AlignmentRunner(
                jackhmmer_binary_path=args.jackhmmer_binary_path,
                hhblits_binary_path=args.hhblits_binary_path,
                hhsearch_binary_path=args.hhsearch_binary_path,
                uniref90_database_path=args.uniref90_database_path,
                mgnify_database_path=args.mgnify_database_path,
                bfd_database_path=args.bfd_database_path,
                uniclust30_database_path=args.uniclust30_database_path,
                pdb70_database_path=args.pdb70_database_path,
                use_small_bfd=use_small_bfd,
                no_cpus=args.cpus,
            )
            alignment_runner.run(fasta_path, local_alignment_dir)

        feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
                                                    alignment_dir=local_alignment_dir)

        # Remove temporary FASTA file
        os.remove(fasta_path)

        processed_feature_dict = feature_processor.process_features(
            feature_dict,
            mode='predict',
        )

        print("Executing model...")
        batch = processed_feature_dict
        with torch.no_grad():
            batch = {k: torch.as_tensor(v, device=args.model_device) for k, v in batch.items()}

            t = time.perf_counter()
            out = model(batch)
            print(f"Inference time: {time.perf_counter() - t}")

        # Toss out the recycling dimensions --- we don't need them anymore
        batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

        plddt = out["plddt"]
        mean_plddt = np.mean(plddt)

        plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)

        unrelaxed_protein = protein.from_prediction(features=batch,
                                                    result=out,
                                                    b_factors=plddt_b_factors)

        # Save the unrelaxed PDB.
        unrelaxed_output_path = os.path.join(args.output_dir,
                                             f'{tag}_{args.model_name}_unrelaxed.pdb')
        with open(unrelaxed_output_path, 'w') as f:
            f.write(protein.to_pdb(unrelaxed_protein))

        amber_relaxer = relax.AmberRelaxation(
            use_gpu=(args.model_device != "cpu"),
            **config.relax,
        )

        # Relax the prediction.
        t = time.perf_counter()
        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
        if ("cuda" in args.model_device):
            device_no = args.model_device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = device_no
        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
        if visible_devices:
            os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
        print(f"Relaxation time: {time.perf_counter() - t}")

        # Save the relaxed PDB.
        relaxed_output_path = os.path.join(args.output_dir,
                                           f'{tag}_{args.model_name}_relaxed.pdb')
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("fasta_path", type=str,)
    parser.add_argument("template_mmcif_dir", type=str,)
    parser.add_argument("--use_precomputed_alignments", type=str, default=None,
                        help="""Path to alignment directory. If provided, alignment computation
                             is skipped and database path arguments are ignored.""")
    parser.add_argument("--output_dir", type=str, default=os.getcwd(),
                        help="""Name of the directory in which to output the prediction""",)
    parser.add_argument("--model_device", type=str, default="cpu",
                        help="""Name of the device on which to run the model. Any valid torch
                             device name is accepted (e.g. "cpu", "cuda:0")""")
    parser.add_argument("--model_name", type=str, default="model_1",
                        help="""Name of a model config. Choose one of model_{1-5} or
                             model_{1-5}_ptm, as defined on the AlphaFold GitHub.""")
    parser.add_argument("--param_path", type=str, default=None,
                        help="""Path to model parameters. If None, parameters are selected
                             automatically according to the model name from
                             openfold/resources/params""")
    parser.add_argument("--cpus", type=int, default=12,
                        help="""Number of CPUs with which to run alignment tools""")
    parser.add_argument('--preset', type=str, default='reduced_dbs',
                        choices=('reduced_dbs', 'full_dbs'))
    parser.add_argument('--data_random_seed', type=str, default=None)
    add_data_args(parser)
    args = parser.parse_args()

    if (args.param_path is None):
        args.param_path = os.path.join("openfold", "resources", "params",
                                       "params_" + args.model_name + ".npz")

    if (args.model_device == "cpu" and torch.cuda.is_available()):
        logging.warning("""The model is being run on CPU. Consider specifying
                        --model_device for better performance""")

    main(args)