OpenDAS / FastFold · Commits

Commit e4119508
Authored Feb 14, 2023 by zhuww

    fix multimer bug

Parent: 614e2763

Showing 17 changed files with 358 additions and 245 deletions (+358 −245):
README.md (+4 −8)
fastfold/common/protein.py (+153 −71)
fastfold/data/data_pipeline.py (+1 −1)
fastfold/data/templates.py (+4 −1)
fastfold/distributed/comm.py (+12 −12)
fastfold/model/fastnn/embedders_multimer.py (+24 −23)
fastfold/model/fastnn/evoformer.py (+2 −2)
fastfold/model/fastnn/kernel/layer_norm.py (+18 −0)
fastfold/model/fastnn/msa.py (+3 −3)
fastfold/model/fastnn/ops.py (+100 −53)
fastfold/model/fastnn/template.py (+30 −61)
fastfold/model/nn/dropout.py (+1 −1)
fastfold/model/nn/evoformer.py (+0 −5)
inference.py (+2 −0)
setup.py (+2 −2)
tests/test_fastnn/test_msa_att_col.py (+1 −1)
tests/test_fastnn/test_template_embedder.py (+1 −1)
README.md

````diff
@@ -23,10 +23,10 @@ FastFold provides a **high-performance implementation of Evoformer** with the fo
 ## Installation
 
-To install and use FastFold, you will need:
+To install FastFold, you will need:
 + Python 3.8 or 3.9.
 + [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above
-+ PyTorch 1.10 or above
++ PyTorch 1.12 or above
 
 For now, You can install FastFold:
@@ -45,14 +45,10 @@ python setup.py install
 #### Advanced
 
-To leverage the power of FastFold, we recommend you build [Triton]() from source.
+To leverage the power of FastFold, we recommend you to install [Triton](https://github.com/openai/triton).
+**[NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.4 or above is needed.**
 
 ```bash
-git clone https://github.com/openai/triton.git ~/triton
-cd ~/triton/python
-pip install -e .
+pip install triton==2.0.0.dev20221005
 ```
````
fastfold/common/protein.py

```diff
@@ -26,6 +26,9 @@ FeatureDict = Mapping[str, np.ndarray]
 ModelOutput = Mapping[str, Any]  # Is a nested dict.
 PICO_TO_ANGSTROM = 0.01
 
+PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
+assert(PDB_MAX_CHAINS == 62)
 
 @dataclasses.dataclass(frozen=True)
 class Protein:
@@ -45,11 +48,22 @@ class Protein:
     # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
     residue_index: np.ndarray  # [num_res]
 
+    # 0-indexed number corresponding to the chain in the protein that this
+    # residue belongs to
+    chain_index: np.ndarray  # [num_res]
+
     # B-factors, or temperature factors, of each residue (in sq. angstroms units),
     # representing the displacement of the residue from its ground truth mean
     # value.
     b_factors: np.ndarray  # [num_res, num_atom_type]
 
+    def __post_init__(self):
+        if (len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
+            raise ValueError(
+                f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
+                "chains because these cannot be written to PDB format"
+            )
+
 
 def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
@@ -60,9 +74,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     Args:
         pdb_str: The contents of the pdb file
-        chain_id: If None, then the pdb file must contain a single chain (which
-          will be parsed). If chain_id is specified (e.g. A), then only that chain
-          is parsed.
+        chain_id: If chain_id is specified (e.g. A), then only that chain is
+          parsed. Else, all chains are parsed.
 
     Returns:
         A new `Protein` parsed from the pdb contents.
@@ -72,62 +85,75 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     structure = parser.get_structure("none", pdb_fh)
     models = list(structure.get_models())
     if len(models) != 1:
         raise ValueError(f"Only single model PDBs are supported. Found {len(models)} models.")
     model = models[0]
 
-    if chain_id is not None:
-        chain = model[chain_id]
-    else:
-        chains = list(model.get_chains())
-        if len(chains) != 1:
-            raise ValueError(
-                "Only single chain PDBs are supported when chain_id not specified. "
-                f"Found {len(chains)} chains."
-            )
-        else:
-            chain = chains[0]
-
     atom_positions = []
     aatype = []
     atom_mask = []
     residue_index = []
+    chain_ids = []
     b_factors = []
 
-    for res in chain:
-        if res.id[2] != " ":
-            raise ValueError(
-                f"PDB contains an insertion code at chain {chain.id} and residue "
-                f"index {res.id[1]}. These are not supported."
-            )
-        res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
-        restype_idx = residue_constants.restype_order.get(
-            res_shortname, residue_constants.restype_num
-        )
-        pos = np.zeros((residue_constants.atom_type_num, 3))
-        mask = np.zeros((residue_constants.atom_type_num,))
-        res_b_factors = np.zeros((residue_constants.atom_type_num,))
-        for atom in res:
-            if atom.name not in residue_constants.atom_types:
-                continue
-            pos[residue_constants.atom_order[atom.name]] = atom.coord
-            mask[residue_constants.atom_order[atom.name]] = 1.0
-            res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
-        if np.sum(mask) < 0.5:
-            # If no known atom positions are reported for the residue then skip it.
-            continue
-        aatype.append(restype_idx)
-        atom_positions.append(pos)
-        atom_mask.append(mask)
-        residue_index.append(res.id[1])
-        b_factors.append(res_b_factors)
+    for chain in model:
+        if (chain_id is not None and chain.id != chain_id):
+            continue
+        for res in chain:
+            if res.id[2] != " ":
+                raise ValueError(
+                    f"PDB contains an insertion code at chain {chain.id} and residue "
+                    f"index {res.id[1]}. These are not supported."
+                )
+            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
+            restype_idx = residue_constants.restype_order.get(
+                res_shortname, residue_constants.restype_num
+            )
+            pos = np.zeros((residue_constants.atom_type_num, 3))
+            mask = np.zeros((residue_constants.atom_type_num,))
+            res_b_factors = np.zeros((residue_constants.atom_type_num,))
+            for atom in res:
+                if atom.name not in residue_constants.atom_types:
+                    continue
+                pos[residue_constants.atom_order[atom.name]] = atom.coord
+                mask[residue_constants.atom_order[atom.name]] = 1.0
+                res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
+            if np.sum(mask) < 0.5:
+                # If no known atom positions are reported for the residue then skip it.
+                continue
+            aatype.append(restype_idx)
+            atom_positions.append(pos)
+            atom_mask.append(mask)
+            residue_index.append(res.id[1])
+            chain_ids.append(chain.id)
+            b_factors.append(res_b_factors)
+
+    # Chain IDs are usually characters so map these to ints
+    unique_chain_ids = np.unique(chain_ids)
+    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
+    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
 
     return Protein(
         atom_positions=np.array(atom_positions),
         atom_mask=np.array(atom_mask),
         aatype=np.array(aatype),
         residue_index=np.array(residue_index),
+        chain_index=chain_index,
         b_factors=np.array(b_factors),
     )
@@ -141,26 +167,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
             if (seq[i] not in residue_constants.restypes):
                 seq[i] = 'X'
             aatype = np.array([
-                residue_constants.restype_order.get(res_symbol, residue_constants.restype_num)
-                for res_symbol in seq
+                residue_constants.restype_order.get(
+                    res_symbol, residue_constants.restype_num
+                ) for res_symbol in seq
             ])
         elif ("[TERTIARY]" == g[0]):
             tertiary = []
             for axis in range(3):
                 tertiary.append(list(map(float, g[1][axis].split())))
             tertiary_np = np.array(tertiary)
             atom_positions = np.zeros(
-                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
+                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)
+            ).astype(np.float32)
             for i, atom in enumerate(atoms):
-                atom_positions[:, residue_constants.atom_order[atom], :] = (np.transpose(
-                    tertiary_np[:, i::3]))
+                atom_positions[:, residue_constants.atom_order[atom], :] = (
+                    np.transpose(tertiary_np[:, i::3])
+                )
             atom_positions *= PICO_TO_ANGSTROM
         elif ("[MASK]" == g[0]):
             mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
-            atom_mask = np.zeros((
-                len(mask),
-                residue_constants.atom_type_num,)
-            ).astype(np.float32)
+            atom_mask = np.zeros(
+                (len(mask), residue_constants.atom_type_num,)
+            ).astype(np.float32)
             for i, atom in enumerate(atoms):
                 atom_mask[:, residue_constants.atom_order[atom]] = 1
             atom_mask *= mask[..., None]
@@ -174,6 +202,14 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
     )
 
 
+def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
+    chain_end = 'TER'
+    return (
+        f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
+        f'{chain_name:>1}{residue_index:>4}'
+    )
+
+
 def to_pdb(prot: Protein) -> str:
     """Converts a `Protein` instance to a PDB string.
@@ -193,19 +229,43 @@ def to_pdb(prot: Protein) -> str:
     aatype = prot.aatype
     atom_positions = prot.atom_positions
     residue_index = prot.residue_index.astype(np.int32)
+    chain_index = prot.chain_index.astype(np.int32)
     b_factors = prot.b_factors
 
     if np.any(aatype > residue_constants.restype_num):
         raise ValueError("Invalid aatypes.")
 
+    # Construct a mapping from chain integer indices to chain ID strings.
+    chain_ids = {}
+    for i in np.unique(chain_index):  # np.unique gives sorted output.
+        if i >= PDB_MAX_CHAINS:
+            raise ValueError(f"The PDB format supports at most {PDB_MAX_CHAINS} chains.")
+        chain_ids[i] = PDB_CHAIN_IDS[i]
+
     pdb_lines.append("MODEL     1")
     atom_index = 1
-    chain_id = "A"
+    last_chain_index = chain_index[0]
     # Add all atom sites.
     for i in range(aatype.shape[0]):
+        # Close the previous chain if in a multichain PDB.
+        if last_chain_index != chain_index[i]:
+            pdb_lines.append(
+                _chain_end(
+                    atom_index, res_1to3(aatype[i - 1]),
+                    chain_ids[chain_index[i - 1]], residue_index[i - 1]
+                )
+            )
+            last_chain_index = chain_index[i]
+            atom_index += 1  # Atom index increases at the TER symbol.
+
         res_name_3 = res_1to3(aatype[i])
-        for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i],
-                                                  b_factors[i]):
+        for atom_name, pos, mask, b_factor in zip(
+            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
+        ):
             if mask < 0.5:
                 continue
@@ -214,28 +274,38 @@ def to_pdb(prot: Protein) -> str:
             alt_loc = ""
             insertion_code = ""
             occupancy = 1.00
             element = atom_name[0]  # Protein supports only C, N, O, S, this works.
             charge = ""
             # PDB is a columnar format, every space matters here!
             atom_line = (
                 f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
-                f"{res_name_3:>3} {chain_id:>1}"
+                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                 f"{residue_index[i]:>4}{insertion_code:>1}   "
                 f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                 f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                 f"{element:>2}{charge:>2}"
             )
             pdb_lines.append(atom_line)
             atom_index += 1
 
-    # Close the chain.
-    chain_end = "TER"
-    chain_termination_line = (
-        f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} "
-        f"{chain_id:>1}{residue_index[-1]:>4}"
-    )
-    pdb_lines.append(chain_termination_line)
+    # Close the final chain.
+    pdb_lines.append(
+        _chain_end(
+            atom_index,
+            res_1to3(aatype[-1]),
+            chain_ids[chain_index[-1]],
+            residue_index[-1]
+        )
+    )
     pdb_lines.append("ENDMDL")
     pdb_lines.append("END")
-    pdb_lines.append("")
-    return "\n".join(pdb_lines)
+
+    # Pad all lines to 80 characters
+    pdb_lines = [line.ljust(80) for line in pdb_lines]
+    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
 
 
 def ideal_atom_mask(prot: Protein) -> np.ndarray:
@@ -258,24 +328,36 @@ def from_prediction(
     features: FeatureDict,
     result: ModelOutput,
     b_factors: Optional[np.ndarray] = None,
+    remove_leading_feature_dimension: bool = False,
 ) -> Protein:
     """Assembles a protein from a prediction.
 
     Args:
         features: Dictionary holding model inputs.
         result: Dictionary holding model outputs.
        b_factors: (Optional) B-factors to use for the protein.
+        remove_leading_feature_dimension: Whether to remove the leading dimension
+            of the `features` values
 
     Returns:
         A protein instance.
     """
+    def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
+        return arr[0] if remove_leading_feature_dimension else arr
+
+    if 'asym_id' in features:
+        chain_index = _maybe_remove_leading_dim(features["asym_id"])
+    else:
+        chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"]))
+
     if b_factors is None:
         b_factors = np.zeros_like(result["final_atom_mask"])
 
     return Protein(
-        aatype=features["aatype"],
+        aatype=_maybe_remove_leading_dim(features["aatype"]),
         atom_positions=result["final_atom_positions"],
         atom_mask=result["final_atom_mask"],
-        residue_index=features["residue_index"] + 1,
+        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
        chain_index=chain_index,
         b_factors=b_factors,
    )
\ No newline at end of file
```
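The core of the multimer fix in protein.py is the new per-residue `chain_index` field plus the `PDB_CHAIN_IDS` table, which together let `from_pdb_string` and `to_pdb` round-trip multi-chain structures. The snippet below is a standalone sketch of that round trip, not FastFold code; the helper names are invented for illustration.

```python
# Minimal sketch of the chain-index round trip the updated protein.py relies on:
# per-residue chain labels -> integer chain_index -> single-character PDB chain IDs.
import numpy as np

PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

def chain_labels_to_index(chain_ids):
    """Map per-residue chain labels (e.g. ['A', 'A', 'B']) to 0-based integers."""
    unique = np.unique(chain_ids)                       # sorted, as in from_pdb_string
    mapping = {cid: n for n, cid in enumerate(unique)}
    return np.array([mapping[cid] for cid in chain_ids])

def index_to_pdb_chain_id(chain_index):
    """Map integer chain indices back to PDB chain IDs, as to_pdb does."""
    if chain_index.max() >= len(PDB_CHAIN_IDS):
        raise ValueError("The PDB format supports at most 62 chains.")
    return [PDB_CHAIN_IDS[i] for i in chain_index]

labels = ["A", "A", "B", "B", "B"]
idx = chain_labels_to_index(labels)        # array([0, 0, 1, 1, 1])
print(index_to_pdb_chain_id(idx))          # ['A', 'A', 'B', 'B', 'B']
```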
fastfold/data/data_pipeline.py

```diff
@@ -249,7 +249,7 @@ def run_msa_tool(
     max_sto_sequences: Optional[int] = None,
 ) -> Mapping[str, Any]:
     """Runs an MSA tool, checking if output already exists first."""
-    if (msa_format == "sto" and max_sto_sequences is not None):
+    if (msa_format == "sto"):
         result = msa_runner.query(fasta_path, max_sto_sequences)[0]
     else:
         result = msa_runner.query(fasta_path)
```
fastfold/data/templates.py

```diff
@@ -866,7 +866,10 @@ def _process_single_hit(
         kalign_binary_path=kalign_binary_path,
         _zero_center_positions=_zero_center_positions,
     )
-    features["template_sum_probs"] = [hit.sum_probs]
+    if hit.sum_probs is None:
+        features['template_sum_probs'] = [0]
+    else:
+        features["template_sum_probs"] = [hit.sum_probs]
 
     # It is possible there were some errors when parsing the other chains in the
     # mmCIF file, but the template features for the chain we want were still
```
fastfold/distributed/comm.py

```diff
@@ -70,7 +70,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
     return output
 
 
-def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
+def _chunk_gather(tensor: Tensor, dim=-1, chunk_size=1) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor
@@ -82,12 +82,12 @@ def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
         world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
         tensor_list = []
         for t in world_list:
-            tensor_list.extend(t.chunk(chunks, dim=1))
-        chunk_tensor = tensor.chunk(chunks, dim=1)
-        for i in range(chunks):
-            _chunk_list = [tensor_list[j * chunks + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
+            tensor_list.extend(t.chunk(chunk_size, dim=1))
+        chunk_tensor = tensor.chunk(chunk_size, dim=1)
+        for i in range(chunk_size):
+            _chunk_list = [tensor_list[j * chunk_size + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
             _chunk_tensor = chunk_tensor[i]
             dist.all_gather(list(_chunk_list),
@@ -103,12 +103,12 @@ def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
         world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=0)
         tensor_list = []
         for t in world_list:
-            tensor_list.extend(t.chunk(chunks, dim=0))
-        chunk_tensor = tensor.chunk(chunks, dim=0)
-        for i in range(chunks):
-            _chunk_list = [tensor_list[j * chunks + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
+            tensor_list.extend(t.chunk(chunk_size, dim=0))
+        chunk_tensor = tensor.chunk(chunk_size, dim=0)
+        for i in range(chunk_size):
+            _chunk_list = [tensor_list[j * chunk_size + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
             _chunk_tensor = chunk_tensor[i]
             dist.all_gather(list(_chunk_list),
@@ -176,14 +176,14 @@ class Reduce(torch.autograd.Function):
         return grad_output
 
 
-def gather(input: Tensor, dim: int = -1, chunks: int = None) -> Tensor:
+def gather(input: Tensor, dim: int = -1, chunk_size: int = None) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Gather.apply(input, dim)
     else:
-        if chunks is None:
+        if chunk_size is None:
             input = _gather(input, dim=dim)
         else:
-            input = _chunk_gather(input, dim=dim, chunks=chunks)
+            input = _chunk_gather(input, dim=dim, chunk_size=chunk_size)
     return input
```
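The comm.py change is only a parameter rename (`chunks` becomes `chunk_size`), but the underlying idea — gathering a tensor in several smaller collectives and reassembling the pieces in rank-major order — is easy to get wrong. The sketch below is a single-process stand-in (no `torch.distributed` process group) that checks only that reassembly bookkeeping; function and variable names are illustrative, not FastFold's.

```python
# Reference check: gathering `chunk_size` slices per rank and rebuilding them in
# rank-major order must equal a plain concatenation of the full per-rank tensors.
# Note that, as in the diff, "chunk_size" here is really the number of chunks per rank.
import torch

def chunked_gather_reference(per_rank_tensors, dim=0, chunk_size=2):
    world_size = len(per_rank_tensors)
    # Split every rank's tensor into chunk_size pieces along dim.
    pieces = [t.chunk(chunk_size, dim=dim) for t in per_rank_tensors]
    # A real implementation all-gathers piece i across all ranks in one collective;
    # here we just read the pieces directly and rebuild rank-major order.
    gathered = []
    for rank in range(world_size):
        for i in range(chunk_size):
            gathered.append(pieces[rank][i])
    return torch.cat(gathered, dim=dim)

ranks = [torch.arange(4) + 10 * r for r in range(3)]   # three simulated ranks
assert torch.equal(chunked_gather_reference(ranks, chunk_size=2),
                   torch.cat(ranks, dim=0))
```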
fastfold/model/fastnn/embedders_multimer.py

```diff
@@ -300,6 +300,7 @@ class TemplateEmbedderMultimer(nn.Module):
     ):
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
+        template_pair_embeddings = torch.zeros((z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
         for i in range(n_templ):
             idx = batch["template_aatype"].new_tensor(i)
             single_template_feats = tensor_tree_map(
@@ -336,7 +337,7 @@ class TemplateEmbedderMultimer(nn.Module):
             rigid_vec = rigid[..., None].inverse().apply_to_point(points)
             unit_vector = rigid_vec.normalized()
-            pair_act = self.template_pair_embedder(
+            pair_embedding = self.template_pair_embedder(
                 template_dgram,
                 aatype_one_hot,
                 z,
@@ -346,7 +347,23 @@ class TemplateEmbedderMultimer(nn.Module):
                 unit_vector,
             )
-            single_template_embeds["template_pair_embedding"] = pair_act
+            if not inplace:
+                # [*, S_t, N, N, C_z]
+                template_pair_embeddings = template_pair_embeddings + self.template_pair_stack(
+                    pair_embedding,
+                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                    chunk_size=chunk_size,
+                    _mask_trans=False,
+                ).squeeze(0)
+            else:
+                # [*, S_t, N, N, C_z]
+                template_pair_embeddings += self.template_pair_stack.inplace(
+                    [pair_embedding],
+                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                    chunk_size=chunk_size,
+                    _mask_trans=False,
+                )[0].squeeze(0)
             single_template_embeds.update(
                 self.template_single_embedder(
                     single_template_feats,
@@ -361,27 +378,11 @@ class TemplateEmbedderMultimer(nn.Module):
             template_embeds,
         )
 
-        if not inplace:
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )
-        else:
-            template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )[0].to(z.device)
-
         # [*, N, N, C_z]
-        template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
-        template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
-        template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
+        template_pair_embeddings = template_pair_embeddings / n_templ
+        template_pair_embeddings = torch.nn.functional.relu(template_pair_embeddings)
+        template_pair_embeddings = self.linear_t(template_pair_embeddings)
+        template_embeds["template_pair_embedding"] = template_pair_embeddings
 
         return template_embeds
```
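With this change, `TemplateEmbedderMultimer` runs the template pair stack once per template inside the loop and keeps only a running sum, dividing by `n_templ` afterwards, instead of stacking all S_t template embeddings and reducing at the end. The snippet below is a schematic of that accumulate-then-average pattern, assuming a stand-in per-template function; it is not the FastFold module itself.

```python
# Schematic of the per-template running sum used by the new embedder: one
# [N, N, C_z] tensor is kept in memory regardless of the number of templates.
import torch

def averaged_template_embedding(per_template_fn, templates, n_res, c_z, device="cpu"):
    acc = torch.zeros((n_res, n_res, c_z), device=device)
    for t in templates:              # one [N, N, C_z] embedding at a time
        acc = acc + per_template_fn(t)
    acc = acc / len(templates)       # mean over templates
    return torch.relu(acc)           # the real code follows this with a linear layer

# Toy usage: the "embedder" just broadcasts a per-template scalar.
templates = [torch.tensor(1.0), torch.tensor(3.0)]
out = averaged_template_embedding(lambda t: t * torch.ones(8, 8, 4), templates, 8, 4)
print(out.shape, out.mean().item())  # torch.Size([8, 8, 4]) 2.0
```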
fastfold/model/fastnn/evoformer.py

```diff
@@ -147,8 +147,8 @@ class Evoformer(nn.Module):
             if self.is_multimer:
                 m[0] = gather(m[0], dim=1)
             else:
-                m[0] = gather(m[0], dim=0, chunks=4)
-                z[0] = gather(z[0], dim=0, chunks=4)
+                m[0] = gather(m[0], dim=0, chunk_size=chunk_size)
+                z[0] = gather(z[0], dim=0, chunk_size=chunk_size)
 
             m[0] = m[0][:, :-padding_size, :]
             z[0] = z[0][:-padding_size, :-padding_size, :]
```
fastfold/model/fastnn/kernel/layer_norm.py

```diff
@@ -34,6 +34,24 @@ class FusedLayerNorm(torch.nn.Module):
             torch.nn.init.zeros_(self.bias)
 
     def forward(self, input):
+        if len(input.shape) >= 3 and input.shape[-3] > 4000:
+            out = torch.empty_like(input)
+            # set max chunk_size = dim / 2, to max compute efficiency
+            chunk_size = min(4000 * 4000 // input.shape[-3], (input.shape[-3] + 1) // 2)
+            if len(input.shape) == 3:
+                for i in range(input.shape[-3]):
+                    out[i:i + chunk_size] = self.kernel_forward(input[i:i + chunk_size])
+            elif len(input.shape) == 4:
+                for j in range(input.shape[-4]):
+                    for i in range(0, input.shape[-3], chunk_size):
+                        out[j, i:i + chunk_size] = self.kernel_forward(input[j, i:i + chunk_size])
+            else:
+                raise RuntimeError("Shape" + input.shape + "not implemented for layernorm yet!")
+            return out
+        else:
+            return self.kernel_forward(input)
+
+    def kernel_forward(self, input):
         if _triton_available:
             return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
                                              self.eps)
```
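The 18 added lines make `FusedLayerNorm` split very long inputs (more than 4000 rows along dim -3) into chunks before calling the kernel, so the fused kernel never sees an oversized launch and the chunking no longer has to live in callers such as template.py. Below is a standalone sketch of the same idea using `torch.nn.functional.layer_norm` in place of the Triton kernel; the names and the 4000 threshold mirror the diff, everything else is an assumption for illustration.

```python
# Chunked LayerNorm over a large leading dimension. Because layer norm only
# normalizes over the trailing `normalized_shape` dims, slicing an earlier dim
# does not change the result.
import torch
import torch.nn.functional as F

def chunked_layer_norm(x, normalized_shape, weight, bias, eps=1e-5, row_limit=4000):
    if x.dim() < 3 or x.shape[-3] <= row_limit:
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
    out = torch.empty_like(x)
    chunk = min(row_limit * row_limit // x.shape[-3], (x.shape[-3] + 1) // 2)
    for i in range(0, x.shape[-3], chunk):
        out[..., i:i + chunk, :, :] = F.layer_norm(
            x[..., i:i + chunk, :, :], normalized_shape, weight, bias, eps
        )
    return out

x = torch.randn(8, 16, 32)
w, b = torch.ones(32), torch.zeros(32)
# row_limit=4 forces the chunked path; the result matches the one-shot op.
assert torch.allclose(chunked_layer_norm(x, (32,), w, b, row_limit=4),
                      F.layer_norm(x, (32,), w, b), atol=1e-6)
```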
fastfold/model/fastnn/msa.py

```diff
@@ -303,7 +303,7 @@ class ExtraMSABlock(nn.Module):
             )
             torch.cuda.empty_cache()
-            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2)
+            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2, drop_unused=True)
             torch.cuda.empty_cache()
             z[0] = scatter(z[0], dim=1, drop_unused=True)
             torch.cuda.empty_cache()
@@ -339,9 +339,9 @@ class ExtraMSABlock(nn.Module):
         if self.last_block:
-            m[0] = gather(m[0], dim=1, chunks=4) if not self.is_multimer else gather(m[0], dim=2)
+            m[0] = gather(m[0], dim=1, chunk_size=chunk_size) if not self.is_multimer else gather(m[0], dim=2)
             torch.cuda.empty_cache()
-            z[0] = gather(z[0], dim=1, chunks=4)
+            z[0] = gather(z[0], dim=1, chunk_size=chunk_size)
 
             m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
             z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
```
fastfold/model/fastnn/ops.py

```diff
@@ -91,20 +91,19 @@ class ChunkTransition(nn.Module):
         self.linear2 = Linear(n * d, d, initializer='zeros')
 
     def forward(self, src):
-        para_dim = src.shape[1]
-        chunk_size = 48
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            out = self.norm(src)
+            out = self.linear2(F.relu(self.linear1(out)))
         else:
             chunk_size = CHUNK_SIZE * 48
-        out = torch.empty_like(src)
-        for ax in range(0, para_dim, chunk_size):
-            if DEBUG and ax > 10:
-                break
-            x = self.norm(src[:, ax:ax + chunk_size, :, :])
-            x = self.linear2(F.relu(self.linear1(x)))
-            out[:, ax:ax + chunk_size, :, :] = x
+            para_dim = src.shape[1]
+            out = torch.empty_like(src)
+            for ax in range(0, para_dim, chunk_size):
+                if DEBUG and ax > 10:
+                    break
+                x = self.norm(src[:, ax:ax + chunk_size, :, :])
+                x = self.linear2(F.relu(self.linear1(x)))
+                out[:, ax:ax + chunk_size, :, :] = x
         out.add_(src)
         return out
@@ -155,18 +154,21 @@ class OutProductMean(nn.Module):
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
         right_act_all = M_mask * right_act_all
 
-        para_dim = left_act.shape[2]
-        chunk_size = CHUNK_SIZE
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
-        for ax in range(0, para_dim, chunk_size):
-            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
-            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
-            O = rearrange(O, 'b i j d e -> b i j (d e)')
-            O = self.o_linear(O)
-            norm0 = norm[:, ax:ax + chunk_size, :, :]
-            Z[:, ax:ax + chunk_size, :, :] = O / norm0
+            out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
+            out = rearrange(out, 'b i j d e -> b i j (d e)')
+            out = self.o_linear(out)
+            Z = out / norm
+        else:
+            para_dim = left_act.shape[2]
+            chunk_size = CHUNK_SIZE
+            for ax in range(0, para_dim, chunk_size):
+                left_act_part = left_act[:, :, ax:ax + chunk_size, :]
+                O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
+                O = rearrange(O, 'b i j d e -> b i j (d e)')
+                O = self.o_linear(O)
+                norm0 = norm[:, ax:ax + chunk_size, :, :]
+                Z[:, ax:ax + chunk_size, :, :] = O / norm0
 
         return Z + Z_raw
@@ -293,11 +295,6 @@ class SelfAttention(nn.Module):
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
 
-        para_dim = in_data.shape[1]
-        chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
         if nonbatched_bias is not None:
             if nonbatched_bias[-1] == -1:
                 bias = nonbatched_bias[0]
@@ -306,7 +303,33 @@ class SelfAttention(nn.Module):
                 bias = gather_async_opp(*nonbatched_bias, dim=1)
                 bias = rearrange(bias, 'b q k h -> b h q k')
 
-        output = []
-        for ax in range(0, para_dim, chunk_size):
-            in_data_part = in_data[:, ax:ax + chunk_size, :, :]
+        if CHUNK_SIZE == None:
+            qkv = self.to_qkv(in_data).chunk(3, dim=-1)
+            q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
+            q = q * self.scaling
+            logits = torch.matmul(q, k.transpose(-1, -2))
+            if nonbatched_bias is not None:
+                weights = fused_softmax(logits, mask, bias.unsqueeze(1))
+            else:
+                weights = fused_softmax(logits, mask)
+            weighted_avg = torch.matmul(weights, v)
+            weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+            if self.gating:
+                gate_values = self.gating_linear(in_data)
+                weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+            output = self.o_linear(weighted_avg)
+        else:
+            para_dim = in_data.shape[1]
+            chunk_size = CHUNK_SIZE
+            output = []
+            for ax in range(0, para_dim, chunk_size):
+                in_data_part = in_data[:, ax:ax + chunk_size, :, :]
@@ -983,16 +1006,17 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
         )
 
     def forward(self, M_raw, M_mask):
-        para_dim = M_raw.shape[2]
         if CHUNK_SIZE is None:
-            chunk_size = para_dim
+            m = self.layernormM(M_raw.transpose(-2, -3))
+            m = self.global_attention(m, M_mask.transpose(-1, -2))
+            m = m.transpose(-2, -3)
+            M_raw = M_raw + m
         else:
             chunk_size = CHUNK_SIZE
-
-        for i in range(0, para_dim, chunk_size):
-            m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
-            m = self.layernormM(m)
-            m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
+            para_dim = M_raw.shape[2]
+            for i in range(0, para_dim, chunk_size):
+                if DEBUG and i > 10:
+                    break
+                m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
+                m = self.layernormM(m)
+                m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
@@ -1111,12 +1135,12 @@ class RecyclingEmbedder(nn.Module):
         # [*, N, N, no_bins]
         d = ((d > squared_bins) * (d < upper)).type(x.dtype)
 
-        # [*, N, N, C_z]
-        para_dim = d.shape[1]
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            d = self.linear(d)
+            z = d + self.layer_norm_z(z)
         else:
             chunk_size = CHUNK_SIZE * 48
-        for i in range(0, para_dim, chunk_size):
-            di = self.linear(d[i:i + chunk_size, :, :])
+            para_dim = d.shape[1]
+            for i in range(0, para_dim, chunk_size):
+                di = self.linear(d[i:i + chunk_size, :, :])
@@ -1154,41 +1178,64 @@ class GlobalAttention(nn.Module):
     def forward(self, m, mask):
-        para_dim = m.shape[1]
-        chunk_size = CHUNK_SIZE
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
-        output = []
-        for ax in range(0, para_dim, chunk_size):
-            m_part = m[:, ax:ax + chunk_size, :, :]
-            mask_part = mask[:, ax:ax + chunk_size, :]
-            q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
-                torch.sum(mask_part, dim=-1)[..., None] + self.eps
-            )
-            q = q * self.scaling
-            q = self.to_q(q)
-            q = q.view(q.shape[:-1] + (self.n_head, -1))
-            k, v = self.to_kv(m_part).chunk(2, dim=-1)
-            logits = torch.matmul(q, k.transpose(-1, -2))
-            weights = fused_softmax(logits, mask_part)
-            weighted_avg = torch.matmul(weights, v)
-            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
-            gate_values = self.gating_linear(m_part)
-            weighted_avg = bias_sigmod_ele(
-                gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
-            )
-            output.append(self.o_linear(weighted_avg))
-        m = torch.cat(output, dim=1)
+            q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
+                torch.sum(mask, dim=-1)[..., None] + self.eps
+            )
+            q = q * self.scaling
+            q = self.to_q(q)
+            q = q.view(q.shape[:-1] + (self.n_head, -1))
+            k, v = self.to_kv(m).chunk(2, dim=-1)
+            logits = torch.matmul(q, k.transpose(-1, -2))
+            weights = fused_softmax(logits, mask)
+            weighted_avg = torch.matmul(weights, v)
+            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
+            gate_values = self.gating_linear(m)
+            weighted_avg = bias_sigmod_ele(
+                gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
+            )
+            m = self.o_linear(weighted_avg)
+        else:
+            para_dim = m.shape[1]
+            chunk_size = CHUNK_SIZE
+            output = []
+            for ax in range(0, para_dim, chunk_size):
+                m_part = m[:, ax:ax + chunk_size, :, :]
+                mask_part = mask[:, ax:ax + chunk_size, :]
+                q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
+                    torch.sum(mask_part, dim=-1)[..., None] + self.eps
+                )
+                q = q * self.scaling
+                q = self.to_q(q)
+                q = q.view(q.shape[:-1] + (self.n_head, -1))
+                k, v = self.to_kv(m_part).chunk(2, dim=-1)
+                logits = torch.matmul(q, k.transpose(-1, -2))
+                weights = fused_softmax(logits, mask_part)
+                weighted_avg = torch.matmul(weights, v)
+                weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
+                gate_values = self.gating_linear(m_part)
+                weighted_avg = bias_sigmod_ele(
+                    gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
+                )
+                output.append(self.o_linear(weighted_avg))
+            m = torch.cat(output, dim=1)
         return m
```
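All of the rewritten modules in ops.py follow one pattern: when `CHUNK_SIZE` is None, run the whole computation in a single shot; otherwise loop over slices of the residue dimension and write into a preallocated output. The snippet below is a minimal, self-contained sketch of that pattern for the outer-product-mean einsum; tensor names follow the diff, everything else is illustrative rather than FastFold's implementation.

```python
# Chunk-or-not outer product mean: the chunked path slices the i dimension and
# must reproduce the one-shot einsum exactly.
import torch

def outer_product_mean(left_act, right_act, norm, chunk_size=None):
    # left_act, right_act: [b, s, i, d] / [b, s, j, e]; norm: [b, i, j, 1]
    if chunk_size is None:
        out = torch.einsum('bsid,bsje->bijde', left_act, right_act)
        out = out.reshape(*out.shape[:3], -1)        # 'b i j d e -> b i j (d e)'
        return out / norm
    b, s, i, d = left_act.shape
    e = right_act.shape[-1]
    Z = left_act.new_empty((b, i, right_act.shape[2], d * e))
    for ax in range(0, i, chunk_size):               # slice the residue dimension
        part = left_act[:, :, ax:ax + chunk_size, :]
        o = torch.einsum('bsid,bsje->bijde', part, right_act)
        o = o.reshape(*o.shape[:3], -1)
        Z[:, ax:ax + chunk_size] = o / norm[:, ax:ax + chunk_size]
    return Z

la = torch.randn(1, 5, 12, 3)
ra = torch.randn(1, 5, 12, 4)
nm = torch.full((1, 12, 12, 1), 5.0)
assert torch.allclose(outer_product_mean(la, ra, nm),
                      outer_product_mean(la, ra, nm, chunk_size=4), atol=1e-5)
```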
fastfold/model/fastnn/template.py

```diff
@@ -241,25 +241,18 @@ class TemplatePairBlock(nn.Module):
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
 
-        # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
-        # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-        for i in range(z.shape[0]):
-            single = z[i].unsqueeze(-4)
-            single_mask = mask[i].unsqueeze(-3)
-
-            single_mask_row = scatter(single_mask, dim=1)
-            single_mask_col = scatter(single_mask, dim=2)
-
-            single = self.TriangleAttentionStartingNode(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleAttentionEndingNode(single, single_mask_col)
-            single = col_to_row(single)
-            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
-            single = self.PairTransition(single)
-            single = col_to_row(single)
-            z[i] = single
+        single_mask_row = scatter(mask, dim=1)
+        single_mask_col = scatter(mask, dim=2)
+
+        z = self.TriangleAttentionStartingNode(z, single_mask_row)
+        z = row_to_col(z)
+        z = self.TriangleAttentionEndingNode(z, single_mask_col)
+        z = col_to_row(z)
+        z = self.TriangleMultiplicationOutgoing(z, single_mask_row)
+        z = row_to_col(z)
+        z = self.TriangleMultiplicationIncoming(z, single_mask_col)
+        z = self.PairTransition(z)
+        z = col_to_row(z)
 
         # z = torch.cat(single_templates, dim=-4)
         if self.last_block:
@@ -275,8 +268,6 @@ class TemplatePairBlock(nn.Module):
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ):
-        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-            z[0] = z[0].cpu()
         dap_size = gpc.get_world_size(ParallelMode.TENSOR)
         seq_length = mask.size(-1)
@@ -290,32 +281,24 @@ class TemplatePairBlock(nn.Module):
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
 
-        # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
-        # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-        for i in range(z[0].shape[0]):
-            torch.cuda.empty_cache()
-            single = z[0][i].unsqueeze(-4).to(mask.device)
-            single_mask = mask[i].unsqueeze(-3)
-
-            single_mask_row = scatter(single_mask, dim=1, drop_unused=True)
-            single_mask_col = scatter(single_mask, dim=2, drop_unused=True)
-
-            single = self.TriangleAttentionStartingNode(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleAttentionEndingNode(single, single_mask_col)
-            single = col_to_row(single)
-            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
-            single = self.PairTransition(single)
-            single = col_to_row(single)
-            z[0][i] = single.to(z[0].device)
+        single_mask_row = scatter(mask, dim=1, drop_unused=True)
+        single_mask_col = scatter(mask, dim=2, drop_unused=True)
+
+        z = self.TriangleAttentionStartingNode.inplace(z, single_mask_row)
+        z[0] = row_to_col(z[0])
+        z = self.TriangleAttentionEndingNode.inplace(z, single_mask_col)
+        z[0] = col_to_row(z[0])
+        z[0] = self.TriangleMultiplicationOutgoing(z[0], single_mask_row)
+        z[0] = row_to_col(z[0])
+        z[0] = self.TriangleMultiplicationIncoming(z[0], single_mask_col)
+        z = self.PairTransition.inplace(z)
+        z[0] = col_to_row(z[0])
 
         # z = torch.cat(single_templates, dim=-4)
         if self.last_block:
-            if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-                z[0] = gather(z[0], dim=1, chunk_size=chunk_size)
-                z[0] = z[0].to(mask.device)
+            z[0] = gather(z[0], dim=1)
 
             z[0] = z[0][:, :-padding_size, :-padding_size, :]
 
             return z
@@ -411,15 +394,8 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
-        if chunk_size is None:
-            chunk_size = t.shape[0]
-        for i in range(0, t.shape[0], chunk_size):
-            if t.shape[1] > 4000:
-                chunk_new = int(4000 * 4000 / t.shape[1])
-                for j in range(0, t.shape[1], chunk_new):
-                    t[i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[i:i + chunk_size, j:j + chunk_new])
-            else:
-                t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
+        for i in range(0, t.shape[0]):
+            t[i] = self.layer_norm(t[i])
         return t
 
     def inplace(
@@ -456,13 +432,6 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
-        if chunk_size is None:
-            chunk_size = t[0].shape[0]
-        for i in range(0, t[0].shape[0], chunk_size):
-            if t[0].shape[1] > 4000:
-                chunk_new = int(4000 * 4000 / t[0].shape[1])
-                for j in range(0, t[0].shape[1], chunk_new):
-                    t[0][i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[0][i:i + chunk_size, j:j + chunk_new].to(mask.device)).to(t[0].device)
-            else:
-                t[0][i:i + chunk_size] = self.layer_norm(t[0][i:i + chunk_size].to(mask.device)).to(t[0].device)
+        for i in range(0, t[0].shape[0]):
+            t[0][i] = self.layer_norm(t[0][i].to(mask.device)).to(t[0].device)
         return t
```
fastfold/model/nn/dropout.py

```diff
@@ -56,7 +56,7 @@ class Dropout(nn.Module):
             shape[bd] = 1
         mask = x.new_ones(shape)
         mask = self.dropout(mask)
-        x *= mask
+        x = x * mask
         return x
```
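The one-character dropout fix replaces an in-place multiply with an out-of-place one. A plausible motivation (an assumption, not stated in the commit) is that mutating a tensor that autograd has saved for the backward pass raises a RuntimeError; the generic PyTorch snippet below reproduces that failure mode and is not FastFold code.

```python
# Why `x *= mask` can be unsafe under autograd, and why `x = x * mask` is not.
import torch

a = torch.randn(3, requires_grad=True)

b = a.exp()          # exp() saves its output for the backward pass
b *= 0.5             # in-place edit of that saved tensor
try:
    b.sum().backward()
except RuntimeError as err:
    print("in-place version fails:", err)

b = a.exp()
b = b * 0.5          # out-of-place, as in the fixed Dropout.forward
b.sum().backward()   # works
print(a.grad)
```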
fastfold/model/nn/evoformer.py

```diff
@@ -264,11 +264,6 @@ class EvoformerBlock(nn.Module):
             eps=eps,
         )
 
-        self.outer_product_mean = OuterProductMean(
-            c_m,
-            c_z,
-            c_hidden_opm,
-        )
-
         self.is_multimer = is_multimer
 
     def forward(self,
```
inference.py

```diff
@@ -227,9 +227,11 @@ def inference_multimer_model(args):
     )
     output_dir_base = args.output_dir
     random_seed = args.data_random_seed
     if random_seed is None:
         random_seed = random.randrange(sys.maxsize)
+
+    # seed_torch(seed=1029)
     feature_processor = feature_pipeline.FeaturePipeline(
         config.data
```
setup.py

```diff
@@ -15,8 +15,8 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    torch_binary_major = torch.version.hip.split(".")[0]
+    torch_binary_minor = torch.version.hip.split(".")[1]
 
     print("\nCompiling cuda extensions with")
```
tests/test_fastnn/test_msa_att_col.py

```diff
@@ -73,4 +73,4 @@ def _test_msa_att_col(rank, world_size, chunk_size, get_openfold_module_and_data
     m_fast = m_fast[:, :-padding_size, :]
     error = torch.max(torch.abs(m_out.cuda() - m_fast))
-    assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
+    assert error < 1e-4, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
```
tests/test_fastnn/test_template_embedder.py

```diff
@@ -46,7 +46,7 @@ def get_openfold_module_and_data():
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('chunk_size', [None, 32])
+@pytest.mark.parametrize('chunk_size', [None, 4])  # should set 4 to test offload
 @pytest.mark.parametrize('inplace', [False, True])
 def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
     run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size,
```