OpenDAS / FastFold / Commits / e4119508
"vscode:/vscode.git/clone" did not exist on "5c75a5fbc421e3126c796822c274b496b6ea71ec"
Commit e4119508, authored Feb 14, 2023 by zhuww

fix multimer bug

parent 614e2763
Showing 17 changed files with 358 additions and 245 deletions (+358 -245)
README.md  +4 -8
fastfold/common/protein.py  +153 -71
fastfold/data/data_pipeline.py  +1 -1
fastfold/data/templates.py  +4 -1
fastfold/distributed/comm.py  +12 -12
fastfold/model/fastnn/embedders_multimer.py  +24 -23
fastfold/model/fastnn/evoformer.py  +2 -2
fastfold/model/fastnn/kernel/layer_norm.py  +18 -0
fastfold/model/fastnn/msa.py  +3 -3
fastfold/model/fastnn/ops.py  +100 -53
fastfold/model/fastnn/template.py  +30 -61
fastfold/model/nn/dropout.py  +1 -1
fastfold/model/nn/evoformer.py  +0 -5
inference.py  +2 -0
setup.py  +2 -2
tests/test_fastnn/test_msa_att_col.py  +1 -1
tests/test_fastnn/test_template_embedder.py  +1 -1
README.md  View file @ e4119508

@@ -23,10 +23,10 @@ FastFold provides a **high-performance implementation of Evoformer** with the fo

 ## Installation

-To install and use FastFold, you will need:
+To install FastFold, you will need:
 + Python 3.8 or 3.9.
 + [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above
-+ PyTorch 1.10 or above
++ PyTorch 1.12 or above

 For now, You can install FastFold:

@@ -45,14 +45,10 @@ python setup.py install

 #### Advanced

-To leverage the power of FastFold, we recommend you build [Triton]() from source.
+To leverage the power of FastFold, we recommend you to install [Triton](https://github.com/openai/triton).
 **[NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.4 or above is needed.**

 ```bash
-git clone https://github.com/openai/triton.git ~/triton
-cd ~/triton/python
-pip install -e .
+pip install triton==2.0.0.dev20221005
 ```
fastfold/common/protein.py  View file @ e4119508

@@ -26,6 +26,9 @@ FeatureDict = Mapping[str, np.ndarray]
 ModelOutput = Mapping[str, Any]  # Is a nested dict.
 PICO_TO_ANGSTROM = 0.01
 
+PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
+assert(PDB_MAX_CHAINS == 62)
 
 @dataclasses.dataclass(frozen=True)
 class Protein:
@@ -46,11 +49,22 @@ class Protein:
     # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
     residue_index: np.ndarray  # [num_res]
 
+    # 0-indexed number corresponding to the chain in the protein that this
+    # residue belongs to
+    chain_index: np.ndarray  # [num_res]
+
     # B-factors, or temperature factors, of each residue (in sq. angstroms units),
     # representing the displacement of the residue from its ground truth mean
     # value.
     b_factors: np.ndarray  # [num_res, num_atom_type]
 
+    def __post_init__(self):
+        if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
+            raise ValueError(
+                f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
+                "chains because these cannot be written to PDB format"
+            )
+
 
 def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     """Takes a PDB string and constructs a Protein object.
@@ -60,9 +74,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     Args:
         pdb_str: The contents of the pdb file
-        chain_id: If None, then the pdb file must contain a single chain (which
-            will be parsed). If chain_id is specified (e.g. A), then only that chain
-            is parsed.
+        chain_id: If chain_id is specified (e.g. A), then only that chain is
+            parsed. Else, all chains are parsed.
 
     Returns:
         A new `Protein` parsed from the pdb contents.
@@ -72,32 +85,33 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     structure = parser.get_structure("none", pdb_fh)
     models = list(structure.get_models())
     if len(models) != 1:
         raise ValueError(f"Only single model PDBs are supported. Found {len(models)} models.")
     model = models[0]
-    if chain_id is not None:
-        chain = model[chain_id]
-    else:
-        chains = list(model.get_chains())
-        if len(chains) != 1:
-            raise ValueError(
-                "Only single chain PDBs are supported when chain_id not specified. "
-                f"Found {len(chains)} chains."
-            )
-        else:
-            chain = chains[0]
 
     atom_positions = []
     aatype = []
     atom_mask = []
     residue_index = []
+    chain_ids = []
     b_factors = []
 
-    for res in chain:
-        if res.id[2] != " ":
-            raise ValueError(
-                f"PDB contains an insertion code at chain {chain.id} and residue "
-                f"index {res.id[1]}. These are not supported."
-            )
-        res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
-        restype_idx = residue_constants.restype_order.get(res_shortname,
-                                                          residue_constants.restype_num)
-        pos = np.zeros((residue_constants.atom_type_num, 3))
-        mask = np.zeros((residue_constants.atom_type_num,))
-        res_b_factors = np.zeros((residue_constants.atom_type_num,))
+    for chain in model:
+        if(chain_id is not None and chain.id != chain_id):
+            continue
+        for res in chain:
+            if res.id[2] != " ":
+                raise ValueError(
+                    f"PDB contains an insertion code at chain {chain.id} and residue "
+                    f"index {res.id[1]}. These are not supported."
+                )
+            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
+            restype_idx = residue_constants.restype_order.get(
+                res_shortname, residue_constants.restype_num
+            )
+            pos = np.zeros((residue_constants.atom_type_num, 3))
+            mask = np.zeros((residue_constants.atom_type_num,))
+            res_b_factors = np.zeros((residue_constants.atom_type_num,))
@@ -106,28 +120,40 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
                 continue
             pos[residue_constants.atom_order[atom.name]] = atom.coord
             mask[residue_constants.atom_order[atom.name]] = 1.0
-            res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
+            res_b_factors[
+                residue_constants.atom_order[atom.name]
+            ] = atom.bfactor
             if np.sum(mask) < 0.5:
                 # If no known atom positions are reported for the residue then skip it.
                 continue
             aatype.append(restype_idx)
             atom_positions.append(pos)
             atom_mask.append(mask)
             residue_index.append(res.id[1])
+            chain_ids.append(chain.id)
             b_factors.append(res_b_factors)
 
+    # Chain IDs are usually characters so map these to ints
+    unique_chain_ids = np.unique(chain_ids)
+    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
+    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
+
     return Protein(
         atom_positions=np.array(atom_positions),
         atom_mask=np.array(atom_mask),
         aatype=np.array(aatype),
         residue_index=np.array(residue_index),
+        chain_index=chain_index,
         b_factors=np.array(b_factors),
     )
 
 
 def from_proteinnet_string(proteinnet_str: str) -> Protein:
     tag_re = r'(\[[A-Z]+\]\n)'
     tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
     groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
 
     atoms = ['N', 'CA', 'C']
@@ -141,26 +167,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
             if(seq[i] not in residue_constants.restypes):
                 seq[i] = 'X'
-            aatype = np.array([
-                residue_constants.restype_order.get(res_symbol, residue_constants.restype_num)
-                for res_symbol in seq
-            ])
+            aatype = np.array([
+                residue_constants.restype_order.get(
+                    res_symbol, residue_constants.restype_num
+                )
+                for res_symbol in seq
+            ])
         elif("[TERTIARY]" == g[0]):
             tertiary = []
             for axis in range(3):
                 tertiary.append(list(map(float, g[1][axis].split())))
             tertiary_np = np.array(tertiary)
-            atom_positions = np.zeros(
-                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
+            atom_positions = np.zeros(
+                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)
+            ).astype(np.float32)
             for i, atom in enumerate(atoms):
-                atom_positions[:, residue_constants.atom_order[atom], :] = (np.transpose(
-                    tertiary_np[:, i::3]))
+                atom_positions[:, residue_constants.atom_order[atom], :] = (
+                    np.transpose(tertiary_np[:, i::3])
+                )
             atom_positions *= PICO_TO_ANGSTROM
         elif("[MASK]" == g[0]):
             mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
-            atom_mask = np.zeros((
-                len(mask),
-                residue_constants.atom_type_num,)
-            ).astype(np.float32)
+            atom_mask = np.zeros(
+                (len(mask), residue_constants.atom_type_num,)
+            ).astype(np.float32)
             for i, atom in enumerate(atoms):
                 atom_mask[:, residue_constants.atom_order[atom]] = 1
             atom_mask *= mask[..., None]
@@ -174,6 +202,14 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
     )
 
 
+def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
+    chain_end = 'TER'
+    return (f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
+            f'{chain_name:>1}{residue_index:>4}')
+
+
 def to_pdb(prot: Protein) -> str:
     """Converts a `Protein` instance to a PDB string.
@@ -193,19 +229,43 @@ def to_pdb(prot: Protein) -> str:
     aatype = prot.aatype
     atom_positions = prot.atom_positions
     residue_index = prot.residue_index.astype(np.int32)
+    chain_index = prot.chain_index.astype(np.int32)
     b_factors = prot.b_factors
 
     if np.any(aatype > residue_constants.restype_num):
         raise ValueError("Invalid aatypes.")
 
+    # Construct a mapping from chain integer indices to chain ID strings.
+    chain_ids = {}
+    for i in np.unique(chain_index):  # np.unique gives sorted output.
+        if i >= PDB_MAX_CHAINS:
+            raise ValueError(f"The PDB format supports at most {PDB_MAX_CHAINS} chains.")
+        chain_ids[i] = PDB_CHAIN_IDS[i]
+
     pdb_lines.append("MODEL     1")
     atom_index = 1
-    chain_id = "A"
+    last_chain_index = chain_index[0]
 
     # Add all atom sites.
     for i in range(aatype.shape[0]):
+        # Close the previous chain if in a multichain PDB.
+        if last_chain_index != chain_index[i]:
+            pdb_lines.append(_chain_end(
+                atom_index, res_1to3(aatype[i - 1]),
+                chain_ids[chain_index[i - 1]], residue_index[i - 1]
+            ))
+            last_chain_index = chain_index[i]
+            atom_index += 1  # Atom index increases at the TER symbol.
+
         res_name_3 = res_1to3(aatype[i])
-        for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i],
-                                                  b_factors[i]):
+        for atom_name, pos, mask, b_factor in zip(
+            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
+        ):
             if mask < 0.5:
                 continue
@@ -214,28 +274,38 @@ def to_pdb(prot: Protein) -> str:
             alt_loc = ""
             insertion_code = ""
             occupancy = 1.00
             element = atom_name[0]  # Protein supports only C, N, O, S, this works.
             charge = ""
             # PDB is a columnar format, every space matters here!
-            atom_line = (f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
-                         f"{res_name_3:>3} {chain_id:>1}"
-                         f"{residue_index[i]:>4}{insertion_code:>1}   "
-                         f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
-                         f"{occupancy:>6.2f}{b_factor:>6.2f}          "
-                         f"{element:>2}{charge:>2}")
+            atom_line = (
+                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
+                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
+                f"{residue_index[i]:>4}{insertion_code:>1}   "
+                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
+                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
+                f"{element:>2}{charge:>2}"
+            )
             pdb_lines.append(atom_line)
             atom_index += 1
 
-    # Close the chain.
-    chain_end = "TER"
-    chain_termination_line = (
-        f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} "
-        f"{chain_id:>1}{residue_index[-1]:>4}"
-    )
-    pdb_lines.append(chain_termination_line)
-    pdb_lines.append("ENDMDL")
+    # Close the final chain.
+    pdb_lines.append(_chain_end(
+        atom_index,
+        res_1to3(aatype[-1]),
+        chain_ids[chain_index[-1]],
+        residue_index[-1]
+    ))
+    pdb_lines.append("ENDMDL")
     pdb_lines.append("END")
-    pdb_lines.append("")
-    return "\n".join(pdb_lines)
+
+    # Pad all lines to 80 characters
+    pdb_lines = [line.ljust(80) for line in pdb_lines]
+    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
 
 
 def ideal_atom_mask(prot: Protein) -> np.ndarray:
@@ -258,24 +328,36 @@ def from_prediction(
     features: FeatureDict,
     result: ModelOutput,
     b_factors: Optional[np.ndarray] = None,
+    remove_leading_feature_dimension: bool = False,
 ) -> Protein:
     """Assembles a protein from a prediction.
 
     Args:
         features: Dictionary holding model inputs.
         result: Dictionary holding model outputs.
        b_factors: (Optional) B-factors to use for the protein.
+        remove_leading_feature_dimension: Whether to remove the leading dimension
+            of the `features` values
 
     Returns:
         A protein instance.
     """
+
+    def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
+        return arr[0] if remove_leading_feature_dimension else arr
+
+    if 'asym_id' in features:
+        chain_index = _maybe_remove_leading_dim(features["asym_id"])
+    else:
+        chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"]))
+
     if b_factors is None:
         b_factors = np.zeros_like(result["final_atom_mask"])
 
     return Protein(
-        aatype=features["aatype"],
+        aatype=_maybe_remove_leading_dim(features["aatype"]),
         atom_positions=result["final_atom_positions"],
         atom_mask=result["final_atom_mask"],
-        residue_index=features["residue_index"] + 1,
+        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
+        chain_index=chain_index,
         b_factors=b_factors,
     )
\ No newline at end of file
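The protein.py changes above carry most of the multimer fix: every residue now records a chain_index, from_pdb_string walks all chains when chain_id is None, and to_pdb closes each chain with a TER record built by _chain_end and pads every line to 80 columns. A minimal round-trip sketch, assuming the fastfold.common.protein module layout shown in the diff (the input file name is hypothetical):

```python
# Round-trip sketch for the multimer-aware Protein dataclass; "complex.pdb" is a
# hypothetical multi-chain PDB file, everything else follows the diff above.
from fastfold.common import protein

with open("complex.pdb") as f:
    pdb_str = f.read()

# chain_id=None now parses every chain; each residue gets a 0-indexed chain_index.
prot = protein.from_pdb_string(pdb_str, chain_id=None)
print(prot.chain_index[:5])  # e.g. [0 0 0 0 0]

# to_pdb() maps chain indices back to A, B, C, ... and emits one TER per chain.
with open("complex_roundtrip.pdb", "w") as f:
    f.write(protein.to_pdb(prot))
```

On the prediction side, from_prediction now accepts a remove_leading_feature_dimension flag and reads the multimer asym_id feature into chain_index, so multimer model outputs can be written through the same path.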
fastfold/data/data_pipeline.py  View file @ e4119508

@@ -249,7 +249,7 @@ def run_msa_tool(
     max_sto_sequences: Optional[int] = None,
 ) -> Mapping[str, Any]:
     """Runs an MSA tool, checking if output already exists first."""
-    if(msa_format == "sto" and max_sto_sequences is not None):
+    if(msa_format == "sto"):
         result = msa_runner.query(fasta_path, max_sto_sequences)[0]
     else:
         result = msa_runner.query(fasta_path)
fastfold/data/templates.py  View file @ e4119508

@@ -866,6 +866,9 @@ def _process_single_hit(
         kalign_binary_path=kalign_binary_path,
         _zero_center_positions=_zero_center_positions,
     )
-    features["template_sum_probs"] = [hit.sum_probs]
+    if hit.sum_probs is None:
+        features['template_sum_probs'] = [0]
+    else:
+        features["template_sum_probs"] = [hit.sum_probs]
 
     # It is possible there were some errors when parsing the other chains in the
fastfold/distributed/comm.py  View file @ e4119508

@@ -70,7 +70,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
     return output
 
 
-def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
+def _chunk_gather(tensor: Tensor, dim=-1, chunk_size=1) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor
@@ -82,12 +82,12 @@ def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
         world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
         tensor_list = []
         for t in world_list:
-            tensor_list.extend(t.chunk(chunks, dim=1))
+            tensor_list.extend(t.chunk(chunk_size, dim=1))
 
-        chunk_tensor = tensor.chunk(chunks, dim=1)
-        for i in range(chunks):
-            _chunk_list = [tensor_list[j * chunks + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
+        chunk_tensor = tensor.chunk(chunk_size, dim=1)
+        for i in range(chunk_size):
+            _chunk_list = [tensor_list[j * chunk_size + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
             _chunk_tensor = chunk_tensor[i]
             dist.all_gather(list(_chunk_list),
@@ -103,12 +103,12 @@ def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
         world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=0)
         tensor_list = []
         for t in world_list:
-            tensor_list.extend(t.chunk(chunks, dim=0))
+            tensor_list.extend(t.chunk(chunk_size, dim=0))
 
-        chunk_tensor = tensor.chunk(chunks, dim=0)
-        for i in range(chunks):
-            _chunk_list = [tensor_list[j * chunks + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
+        chunk_tensor = tensor.chunk(chunk_size, dim=0)
+        for i in range(chunk_size):
+            _chunk_list = [tensor_list[j * chunk_size + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
             _chunk_tensor = chunk_tensor[i]
             dist.all_gather(list(_chunk_list),
@@ -176,14 +176,14 @@ class Reduce(torch.autograd.Function):
         return grad_output
 
 
-def gather(input: Tensor, dim: int = -1, chunks: int = None) -> Tensor:
+def gather(input: Tensor, dim: int = -1, chunk_size: int = None) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Gather.apply(input, dim)
     else:
-        if chunks is None:
+        if chunk_size is None:
             input = _gather(input, dim=dim)
         else:
-            input = _chunk_gather(input, dim=dim, chunks=chunks)
+            input = _chunk_gather(input, dim=dim, chunk_size=chunk_size)
     return input
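The comm.py hunks are a pure rename of the gather keyword from chunks to chunk_size, matching the call sites updated in evoformer.py, msa.py, and template.py below. A hedged usage sketch, assuming fastfold.distributed.comm is importable and the tensor-parallel process group has already been initialized elsewhere:

```python
# Illustrative only: shapes are arbitrary and tensor-parallel initialization is
# assumed to have happened before this point.
import torch
from fastfold.distributed.comm import gather

x_part = torch.randn(1, 32, 256, device="cuda")   # this rank's shard of a larger tensor
x_full = gather(x_part, dim=1, chunk_size=4)       # was gather(..., chunks=4) before this commit
```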
fastfold/model/fastnn/embedders_multimer.py  View file @ e4119508

@@ -300,6 +300,7 @@ class TemplateEmbedderMultimer(nn.Module):
     ):
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
+        template_pair_embeddings = torch.zeros((z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
         for i in range(n_templ):
             idx = batch["template_aatype"].new_tensor(i)
             single_template_feats = tensor_tree_map(
@@ -336,7 +337,7 @@ class TemplateEmbedderMultimer(nn.Module):
             rigid_vec = rigid[..., None].inverse().apply_to_point(points)
             unit_vector = rigid_vec.normalized()
-            pair_act = self.template_pair_embedder(
+            pair_embedding = self.template_pair_embedder(
                 template_dgram,
                 aatype_one_hot,
                 z,
@@ -346,7 +347,23 @@ class TemplateEmbedderMultimer(nn.Module):
                 unit_vector,
             )
-            single_template_embeds["template_pair_embedding"] = pair_act
+            if not inplace:
+                # [*, S_t, N, N, C_z]
+                template_pair_embeddings = template_pair_embeddings + self.template_pair_stack(
+                    pair_embedding,
+                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                    chunk_size=chunk_size,
+                    _mask_trans=False,
+                ).squeeze(0)
+            else:
+                # [*, S_t, N, N, C_z]
+                template_pair_embeddings += self.template_pair_stack.inplace(
+                    [pair_embedding],
+                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                    chunk_size=chunk_size,
+                    _mask_trans=False,
+                )[0].squeeze(0)
             single_template_embeds.update(
                 self.template_single_embedder(
                     single_template_feats,
@@ -361,27 +378,11 @@ class TemplateEmbedderMultimer(nn.Module):
             template_embeds,
         )
 
-        if not inplace:
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )
-        else:
-            template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )[0].to(z.device)
-
         # [*, N, N, C_z]
-        template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
-        template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
-        template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
+        template_pair_embeddings = template_pair_embeddings / n_templ
+        template_pair_embeddings = torch.nn.functional.relu(template_pair_embeddings)
+        template_pair_embeddings = self.linear_t(template_pair_embeddings)
+        template_embeds["template_pair_embedding"] = template_pair_embeddings
 
         return template_embeds
fastfold/model/fastnn/evoformer.py  View file @ e4119508

@@ -147,8 +147,8 @@ class Evoformer(nn.Module):
             if self.is_multimer:
                 m[0] = gather(m[0], dim=1)
             else:
-                m[0] = gather(m[0], dim=0, chunks=4)
-                z[0] = gather(z[0], dim=0, chunks=4)
+                m[0] = gather(m[0], dim=0, chunk_size=chunk_size)
+                z[0] = gather(z[0], dim=0, chunk_size=chunk_size)
 
             m[0] = m[0][:, :-padding_size, :]
             z[0] = z[0][:-padding_size, :-padding_size, :]
fastfold/model/fastnn/kernel/layer_norm.py  View file @ e4119508

@@ -34,6 +34,24 @@ class FusedLayerNorm(torch.nn.Module):
             torch.nn.init.zeros_(self.bias)
 
     def forward(self, input):
+        if len(input.shape) >= 3 and input.shape[-3] > 4000:
+            out = torch.empty_like(input)
+            # set max chunk_size = dim / 2, to max compute efficiency
+            chunk_size = min(4000 * 4000 // input.shape[-3], (input.shape[-3] + 1) // 2)
+            if len(input.shape) == 3:
+                for i in range(input.shape[-3]):
+                    out[i:i + chunk_size] = self.kernel_forward(input[i:i + chunk_size])
+            elif len(input.shape) == 4:
+                for j in range(input.shape[-4]):
+                    for i in range(0, input.shape[-3], chunk_size):
+                        out[j, i:i + chunk_size] = self.kernel_forward(input[j, i:i + chunk_size])
+            else:
+                raise RuntimeError("Shape" + input.shape + "not implemented for layernorm yet!")
+            return out
+        else:
+            return self.kernel_forward(input)
+
+    def kernel_forward(self, input):
         if _triton_available:
             return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
                                              self.eps)
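The new forward wrapper only changes how the existing kernel is applied: inputs whose third-from-last dimension exceeds 4000 are normalized slice by slice to cap peak activation memory, while kernel_forward keeps the original Triton/torch path. A standalone sketch of the same idea using a plain torch.nn.LayerNorm in place of the fused kernel (names here are illustrative, not part of FastFold):

```python
import torch

def chunked_layer_norm(x: torch.Tensor, norm: torch.nn.LayerNorm, threshold: int = 4000) -> torch.Tensor:
    # Small inputs take the ordinary single-shot path.
    if x.dim() < 3 or x.shape[-3] <= threshold:
        return norm(x)
    out = torch.empty_like(x)
    # Same sizing rule as the diff: keep slice_len * dim near threshold**2,
    # and never use more than half of the dimension being split.
    chunk = min(threshold * threshold // x.shape[-3], (x.shape[-3] + 1) // 2)
    for i in range(0, x.shape[-3], chunk):
        out[..., i:i + chunk, :, :] = norm(x[..., i:i + chunk, :, :])
    return out

norm = torch.nn.LayerNorm(64)
y = chunked_layer_norm(torch.randn(4096, 128, 64), norm)
```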
fastfold/model/fastnn/msa.py  View file @ e4119508

@@ -303,7 +303,7 @@ class ExtraMSABlock(nn.Module):
             )
             torch.cuda.empty_cache()
-            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2)
+            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2, drop_unused=True)
             torch.cuda.empty_cache()
             z[0] = scatter(z[0], dim=1, drop_unused=True)
             torch.cuda.empty_cache()
@@ -339,9 +339,9 @@ class ExtraMSABlock(nn.Module):
         if self.last_block:
-            m[0] = gather(m[0], dim=1, chunks=4) if not self.is_multimer else gather(m[0], dim=2)
+            m[0] = gather(m[0], dim=1, chunk_size=chunk_size) if not self.is_multimer else gather(m[0], dim=2)
             torch.cuda.empty_cache()
-            z[0] = gather(z[0], dim=1, chunks=4)
+            z[0] = gather(z[0], dim=1, chunk_size=chunk_size)
 
             m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
             z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
fastfold/model/fastnn/ops.py  View file @ e4119508

@@ -91,13 +91,12 @@ class ChunkTransition(nn.Module):
         self.linear2 = Linear(n * d, d, initializer='zeros')
 
     def forward(self, src):
-        para_dim = src.shape[1]
-        chunk_size = 48
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            out = self.norm(src)
+            out = self.linear2(F.relu(self.linear1(out)))
         else:
             chunk_size = CHUNK_SIZE * 48
+            para_dim = src.shape[1]
             out = torch.empty_like(src)
             for ax in range(0, para_dim, chunk_size):
                 if DEBUG and ax > 10:
@@ -155,11 +154,14 @@ class OutProductMean(nn.Module):
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
         right_act_all = M_mask * right_act_all
 
+        if CHUNK_SIZE == None:
+            out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
+            out = rearrange(out, 'b i j d e -> b i j (d e)')
+            out = self.o_linear(out)
+            Z = out / norm
+        else:
             para_dim = left_act.shape[2]
             chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
 
             for ax in range(0, para_dim, chunk_size):
                 left_act_part = left_act[:, :, ax:ax + chunk_size, :]
                 O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
@@ -293,11 +295,6 @@ class SelfAttention(nn.Module):
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
 
-        para_dim = in_data.shape[1]
-        chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
         if nonbatched_bias is not None:
             if nonbatched_bias[-1] == -1:
                 bias = nonbatched_bias[0]
@@ -306,6 +303,32 @@ class SelfAttention(nn.Module):
                 bias = gather_async_opp(*nonbatched_bias, dim=1)
             bias = rearrange(bias, 'b q k h -> b h q k')
 
+        if CHUNK_SIZE == None:
+            qkv = self.to_qkv(in_data).chunk(3, dim=-1)
+            q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
+
+            q = q * self.scaling
+
+            logits = torch.matmul(q, k.transpose(-1, -2))
+
+            if nonbatched_bias is not None:
+                weights = fused_softmax(logits, mask, bias.unsqueeze(1))
+            else:
+                weights = fused_softmax(logits, mask)
+
+            weighted_avg = torch.matmul(weights, v)
+            weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+
+            if self.gating:
+                gate_values = self.gating_linear(in_data)
+                weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+
+            output = self.o_linear(weighted_avg)
+        else:
+            para_dim = in_data.shape[1]
+            chunk_size = CHUNK_SIZE
 
             output = []
             for ax in range(0, para_dim, chunk_size):
@@ -983,16 +1006,17 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
         )
 
     def forward(self, M_raw, M_mask):
-        para_dim = M_raw.shape[2]
         if CHUNK_SIZE is None:
-            chunk_size = para_dim
+            m = self.layernormM(M_raw.transpose(-2, -3))
+            m = self.global_attention(m, M_mask.transpose(-1, -2))
+            m = m.transpose(-2, -3)
+            M_raw = M_raw + m
         else:
             chunk_size = CHUNK_SIZE
+            para_dim = M_raw.shape[2]
 
             for i in range(0, para_dim, chunk_size):
                 if DEBUG and i > 10:
                     break
 
                 m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
                 m = self.layernormM(m)
                 m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
@@ -1111,12 +1135,12 @@ class RecyclingEmbedder(nn.Module):
         # [*, N, N, no_bins]
         d = ((d > squared_bins) * (d < upper)).type(x.dtype)
 
         # [*, N, N, C_z]
-        para_dim = d.shape[1]
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            d = self.linear(d)
+            z = d + self.layer_norm_z(z)
         else:
             chunk_size = CHUNK_SIZE * 48
+            para_dim = d.shape[1]
             for i in range(0, para_dim, chunk_size):
                 di = self.linear(d[i:i + chunk_size, :, :])
@@ -1154,10 +1178,33 @@ class GlobalAttention(nn.Module):
     def forward(self, m, mask):
+        if CHUNK_SIZE == None:
+            q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)
+            q = q * self.scaling
+            q = self.to_q(q)
+            q = q.view(q.shape[:-1] + (self.n_head, -1))
+            k, v = self.to_kv(m).chunk(2, dim=-1)
+
+            logits = torch.matmul(q, k.transpose(-1, -2))
+
+            weights = fused_softmax(logits, mask)
+
+            weighted_avg = torch.matmul(weights, v)
+            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
+
+            gate_values = self.gating_linear(m)
+            weighted_avg = bias_sigmod_ele(
+                gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
+            )
+
+            m = self.o_linear(weighted_avg)
+        else:
             para_dim = m.shape[1]
             chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
             output = []
             for ax in range(0, para_dim, chunk_size):
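Most of the ops.py hunks follow the same pattern: when CHUNK_SIZE is unset the module now runs its computation in one shot, otherwise it iterates over slices of the parallel dimension as before. A small self-contained check of that pattern for the outer-product einsum used by OutProductMean (plain PyTorch, illustrative shapes only):

```python
import torch

left = torch.randn(1, 8, 32, 16)    # [b, s, i, d]
right = torch.randn(1, 8, 32, 16)   # [b, s, j, e]

# Unchunked path: one einsum over the full i dimension.
full = torch.einsum('bsid,bsje->bijde', left, right)

# Chunked path: slice the i dimension and concatenate, as the chunked branch does.
chunk_size = 8
parts = [torch.einsum('bsid,bsje->bijde', left[:, :, ax:ax + chunk_size, :], right)
         for ax in range(0, left.shape[2], chunk_size)]
chunked = torch.cat(parts, dim=1)

assert torch.allclose(full, chunked, atol=1e-5)
```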
fastfold/model/fastnn/template.py  View file @ e4119508

@@ -241,25 +241,18 @@ class TemplatePairBlock(nn.Module):
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
 
         # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
         # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-        for i in range(z.shape[0]):
-            single = z[i].unsqueeze(-4)
-            single_mask = mask[i].unsqueeze(-3)
-            single_mask_row = scatter(single_mask, dim=1)
-            single_mask_col = scatter(single_mask, dim=2)
-            single = self.TriangleAttentionStartingNode(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleAttentionEndingNode(single, single_mask_col)
-            single = col_to_row(single)
-            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
-            single = self.PairTransition(single)
-            single = col_to_row(single)
-            z[i] = single
+        single_mask_row = scatter(mask, dim=1)
+        single_mask_col = scatter(mask, dim=2)
+        z = self.TriangleAttentionStartingNode(z, single_mask_row)
+        z = row_to_col(z)
+        z = self.TriangleAttentionEndingNode(z, single_mask_col)
+        z = col_to_row(z)
+        z = self.TriangleMultiplicationOutgoing(z, single_mask_row)
+        z = row_to_col(z)
+        z = self.TriangleMultiplicationIncoming(z, single_mask_col)
+        z = self.PairTransition(z)
+        z = col_to_row(z)
         # z = torch.cat(single_templates, dim=-4)
 
         if self.last_block:
@@ -275,8 +268,6 @@ class TemplatePairBlock(nn.Module):
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ):
-        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-            z[0] = z[0].cpu()
         dap_size = gpc.get_world_size(ParallelMode.TENSOR)
 
         seq_length = mask.size(-1)
@@ -290,32 +281,24 @@ class TemplatePairBlock(nn.Module):
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
 
         # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
         # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-        for i in range(z[0].shape[0]):
-            torch.cuda.empty_cache()
-            single = z[0][i].unsqueeze(-4).to(mask.device)
-            single_mask = mask[i].unsqueeze(-3)
-            single_mask_row = scatter(single_mask, dim=1, drop_unused=True)
-            single_mask_col = scatter(single_mask, dim=2, drop_unused=True)
-            single = self.TriangleAttentionStartingNode(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleAttentionEndingNode(single, single_mask_col)
-            single = col_to_row(single)
-            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
-            single = self.PairTransition(single)
-            single = col_to_row(single)
-            z[0][i] = single.to(z[0].device)
+        single_mask_row = scatter(mask, dim=1, drop_unused=True)
+        single_mask_col = scatter(mask, dim=2, drop_unused=True)
+        z = self.TriangleAttentionStartingNode.inplace(z, single_mask_row)
+        z[0] = row_to_col(z[0])
+        z = self.TriangleAttentionEndingNode.inplace(z, single_mask_col)
+        z[0] = col_to_row(z[0])
+        z[0] = self.TriangleMultiplicationOutgoing(z[0], single_mask_row)
+        z[0] = row_to_col(z[0])
+        z[0] = self.TriangleMultiplicationIncoming(z[0], single_mask_col)
+        z = self.PairTransition.inplace(z)
+        z[0] = col_to_row(z[0])
         # z = torch.cat(single_templates, dim=-4)
 
         if self.last_block:
-            if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-                z[0] = gather(z[0], dim=1, chunk_size=chunk_size)
-                z[0] = z[0].to(mask.device)
+            z[0] = gather(z[0], dim=1)
 
             z[0] = z[0][:, :-padding_size, :-padding_size, :]
 
         return z
@@ -411,15 +394,8 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
-        if chunk_size is None:
-            chunk_size = t.shape[0]
-        for i in range(0, t.shape[0], chunk_size):
-            if t.shape[1] > 4000:
-                chunk_new = int(4000 * 4000 / t.shape[1])
-                for j in range(0, t.shape[1], chunk_new):
-                    t[i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[i:i + chunk_size, j:j + chunk_new])
-            else:
-                t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
+        for i in range(0, t.shape[0]):
+            t[i] = self.layer_norm(t[i])
 
         return t
 
     def inplace(
@@ -456,13 +432,6 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
-        if chunk_size is None:
-            chunk_size = t[0].shape[0]
-        for i in range(0, t[0].shape[0], chunk_size):
-            if t[0].shape[1] > 4000:
-                chunk_new = int(4000 * 4000 / t[0].shape[1])
-                for j in range(0, t[0].shape[1], chunk_new):
-                    t[0][i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[0][i:i + chunk_size, j:j + chunk_new].to(mask.device)).to(t[0].device)
-            else:
-                t[0][i:i + chunk_size] = self.layer_norm(t[0][i:i + chunk_size].to(mask.device)).to(t[0].device)
+        for i in range(0, t[0].shape[0]):
+            t[0][i] = self.layer_norm(t[0][i].to(mask.device)).to(t[0].device)
 
         return t
fastfold/model/nn/dropout.py  View file @ e4119508

@@ -56,7 +56,7 @@ class Dropout(nn.Module):
             shape[bd] = 1
         mask = x.new_ones(shape)
         mask = self.dropout(mask)
-        x *= mask
+        x = x * mask
         return x
fastfold/model/nn/evoformer.py  View file @ e4119508

@@ -264,11 +264,6 @@ class EvoformerBlock(nn.Module):
             eps=eps,
         )
 
-        self.outer_product_mean = OuterProductMean(
-            c_m,
-            c_z,
-            c_hidden_opm,
-        )
-
         self.is_multimer = is_multimer
 
     def forward(self,
inference.py  View file @ e4119508

@@ -227,9 +227,11 @@ def inference_multimer_model(args):
     )
     output_dir_base = args.output_dir
     random_seed = args.data_random_seed
     if random_seed is None:
         random_seed = random.randrange(sys.maxsize)
+
+    # seed_torch(seed=1029)
     feature_processor = feature_pipeline.FeaturePipeline(
         config.data
setup.py  View file @ e4119508

@@ -15,8 +15,8 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
 
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    torch_binary_major = torch.version.hip.split(".")[0]
+    torch_binary_minor = torch.version.hip.split(".")[1]
 
     print("\nCompiling cuda extensions with")
tests/test_fastnn/test_msa_att_col.py  View file @ e4119508

@@ -73,4 +73,4 @@ def _test_msa_att_col(rank, world_size, chunk_size, get_openfold_module_and_data
 
     m_fast = m_fast[:, :-padding_size, :]
     error = torch.max(torch.abs(m_out.cuda() - m_fast))
-    assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
+    assert error < 1e-4, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
tests/test_fastnn/test_template_embedder.py  View file @ e4119508

@@ -46,7 +46,7 @@ def get_openfold_module_and_data():
 
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('chunk_size', [None, 32])
+@pytest.mark.parametrize('chunk_size', [None, 4])  # should set 4 to test offload
 @pytest.mark.parametrize('inplace', [False, True])
 def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
     run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size,