OpenDAS / FastFold / Commits

Commit e4119508, authored Feb 14, 2023 by zhuww (parent 614e2763)

    fix multimer bug

Showing 17 changed files with 358 additions and 245 deletions (+358 -245).
Changed files:

README.md                                      +4   -8
fastfold/common/protein.py                     +153 -71
fastfold/data/data_pipeline.py                 +1   -1
fastfold/data/templates.py                     +4   -1
fastfold/distributed/comm.py                   +12  -12
fastfold/model/fastnn/embedders_multimer.py    +24  -23
fastfold/model/fastnn/evoformer.py             +2   -2
fastfold/model/fastnn/kernel/layer_norm.py     +18  -0
fastfold/model/fastnn/msa.py                   +3   -3
fastfold/model/fastnn/ops.py                   +100 -53
fastfold/model/fastnn/template.py              +30  -61
fastfold/model/nn/dropout.py                   +1   -1
fastfold/model/nn/evoformer.py                 +0   -5
inference.py                                   +2   -0
setup.py                                       +2   -2
tests/test_fastnn/test_msa_att_col.py          +1   -1
tests/test_fastnn/test_template_embedder.py    +1   -1
README.md (+4 -8)

@@ -23,10 +23,10 @@

The requirements list is tightened and the PyTorch floor is raised:

 ## Installation

-To install and use FastFold, you will need:
+To install FastFold, you will need:
 + Python 3.8 or 3.9.
 + [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above
-+ PyTorch 1.10 or above
++ PyTorch 1.12 or above

 For now, You can install FastFold:

@@ -45,14 +45,10 @@

Building Triton from source is no longer recommended; a pinned pip install replaces it:

 #### Advanced

-To leverage the power of FastFold, we recommend you build [Triton]() from source.
-**[NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.4 or above is needed.**
+To leverage the power of FastFold, we recommend you to install [Triton](https://github.com/openai/triton).

 ```bash
-git clone https://github.com/openai/triton.git ~/triton
-cd ~/triton/python
-pip install -e .
+pip install triton==2.0.0.dev20221005
 ```
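The README changes above only restate version floors. As a convenience, a small hedged sketch follows, not part of FastFold, that checks a local environment against those stated requirements; the helper name `check_environment` is hypothetical.

```python
# Hypothetical helper (not part of FastFold) that checks the environment
# against the README requirements: Python 3.8/3.9, CUDA >= 11.1, PyTorch >= 1.12.
import sys

import torch


def check_environment() -> None:
    py = sys.version_info
    if (py.major, py.minor) not in [(3, 8), (3, 9)]:
        print(f"Warning: Python {py.major}.{py.minor} is not the tested 3.8/3.9")

    torch_major, torch_minor = (int(x) for x in torch.__version__.split(".")[:2])
    if (torch_major, torch_minor) < (1, 12):
        print(f"Warning: PyTorch {torch.__version__} is older than 1.12")

    cuda = torch.version.cuda  # None for CPU-only builds
    if cuda is None:
        print("Warning: this PyTorch build has no CUDA support")
    elif tuple(int(x) for x in cuda.split(".")[:2]) < (11, 1):
        print(f"Warning: CUDA {cuda} is older than 11.1")


if __name__ == "__main__":
    check_environment()
```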
fastfold/common/protein.py (+153 -71)

This file is brought in line with the multimer-aware protein representation: a per-residue chain_index field is added, PDB parsing handles multiple chains, PDB writing emits TER records between chains, and from_prediction() learns to strip the leading feature dimension and to read asym_id.

@@ -26,6 +26,9 @@ FeatureDict = Mapping[str, np.ndarray]

 ModelOutput = Mapping[str, Any]  # Is a nested dict.
 PICO_TO_ANGSTROM = 0.01
+PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
+assert PDB_MAX_CHAINS == 62

 @dataclasses.dataclass(frozen=True)
 class Protein:

@@ -46,11 +49,22 @@ class Protein:

 # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
 residue_index: np.ndarray  # [num_res]
+
+# 0-indexed number corresponding to the chain in the protein that this
+# residue belongs to
+chain_index: np.ndarray  # [num_res]

 # B-factors, or temperature factors, of each residue (in sq. angstroms units),
 # representing the displacement of the residue from its ground truth mean
 # value.
 b_factors: np.ndarray  # [num_res, num_atom_type]

+def __post_init__(self):
+    if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
+        raise ValueError(
+            f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
+            "chains because these cannot be written to PDB format"
+        )

@@ -60,9 +74,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:

The docstring now describes the multichain behaviour:

-    chain_id: If None, then the pdb file must contain a single chain (which
-      will be parsed). If chain_id is specified (e.g. A), then only that chain
-      is parsed.
+    chain_id: If chain_id is specified (e.g. A), then only that chain is
+      parsed. Else, all chains are parsed.

@@ -72,32 +85,33 @@ and @@ -106,28 +120,40 @@ def from_pdb_string(...)

The single-chain restriction is dropped. The parser still rejects multi-model PDBs, but it now walks every chain of the single model, skips chains that do not match an explicit chain_id, records chain.id per residue, and finally maps the chain letters to integer indices:

     for chain in model:
         if chain_id is not None and chain.id != chain_id:
             continue
         for res in chain:
             if res.id[2] != " ":
                 raise ValueError(
                     f"PDB contains an insertion code at chain {chain.id} and residue "
                     f"index {res.id[1]}. These are not supported."
                 )
             ...
             residue_index.append(res.id[1])
             chain_ids.append(chain.id)
             b_factors.append(res_b_factors)

     # Chain IDs are usually characters so map these to ints
     unique_chain_ids = np.unique(chain_ids)
     chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
     chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

     return Protein(
         atom_positions=np.array(atom_positions),
         atom_mask=np.array(atom_mask),
         aatype=np.array(aatype),
         residue_index=np.array(residue_index),
         chain_index=chain_index,
         b_factors=np.array(b_factors),
     )

@@ -141,26 +167,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:

Whitespace-only reflow of the aatype comprehension, the [TERTIARY] and [MASK] branches, and the atom_positions / atom_mask allocations; no behavioural change.

@@ -174,6 +202,14 @@

A helper for writing chain-termination records is added:

+def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
+    chain_end = 'TER'
+    return (f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
+            f'{chain_name:>1}{residue_index:>4}')

@@ -193,19 +229,43 @@ and @@ -214,28 +274,38 @@ def to_pdb(prot: Protein) -> str:

to_pdb() becomes multichain-aware: it reads prot.chain_index, builds a chain-index-to-letter mapping (raising if more than PDB_MAX_CHAINS chains are present), closes a chain with _chain_end() whenever chain_index changes between residues (the atom index increases at the TER symbol), writes chain_ids[chain_index[i]] into each ATOM line instead of the previously hard-coded chain_id = "A", closes the final chain with _chain_end(), and pads every line to 80 characters before joining with a terminating newline:

+    chain_index = prot.chain_index.astype(np.int32)
     ...
+    # Construct a mapping from chain integer indices to chain ID strings.
+    chain_ids = {}
+    for i in np.unique(chain_index):  # np.unique gives sorted output.
+        if i >= PDB_MAX_CHAINS:
+            raise ValueError(f"The PDB format supports at most {PDB_MAX_CHAINS} chains.")
+        chain_ids[i] = PDB_CHAIN_IDS[i]
     ...
-    chain_id = "A"
+    last_chain_index = chain_index[0]
     # Add all atom sites.
     for i in range(aatype.shape[0]):
+        # Close the previous chain if in a multichain PDB.
+        if last_chain_index != chain_index[i]:
+            pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[i - 1]),
+                                        chain_ids[chain_index[i - 1]], residue_index[i - 1]))
+            last_chain_index = chain_index[i]
+            atom_index += 1  # Atom index increases at the TER symbol.
         ...
+    # Close the final chain.
+    pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
+                                chain_ids[chain_index[-1]], residue_index[-1]))
+    pdb_lines.append("ENDMDL")
+    pdb_lines.append("END")
+
+    # Pad all lines to 80 characters
+    pdb_lines = [line.ljust(80) for line in pdb_lines]
+    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.

@@ -258,24 +328,36 @@ def from_prediction(

from_prediction() gains a remove_leading_feature_dimension flag (documented as "Whether to remove the leading dimension of the `features` values") and a chain_index derived from the multimer asym_id feature when present:

+    remove_leading_feature_dimension: bool = False,
     ...
+    def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
+        return arr[0] if remove_leading_feature_dimension else arr
+
+    if 'asym_id' in features:
+        chain_index = _maybe_remove_leading_dim(features["asym_id"])
+    else:
+        chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"]))

     if b_factors is None:
         b_factors = np.zeros_like(result["final_atom_mask"])

     return Protein(
-        aatype=features["aatype"],
+        aatype=_maybe_remove_leading_dim(features["aatype"]),
         atom_positions=result["final_atom_positions"],
         atom_mask=result["final_atom_mask"],
-        residue_index=features["residue_index"] + 1,
+        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
+        chain_index=chain_index,
         b_factors=b_factors,
     )
(No newline at end of file)
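The two ideas that carry the multimer fix in protein.py are the chain-letter-to-index mapping and the TER record written between chains. A minimal standalone sketch of both follows; it is independent of FastFold, and the function names here are illustrative rather than the library's own.

```python
# Minimal sketch (independent of FastFold) of mapping per-residue chain
# letters to integer indices and formatting a TER record when a chain ends.
import numpy as np

PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"


def chain_letters_to_index(chain_ids: list) -> np.ndarray:
    """Map per-residue chain letters (e.g. ['A', 'A', 'B']) to 0-based ints."""
    unique = np.unique(chain_ids)                 # sorted, as in the diff
    mapping = {cid: n for n, cid in enumerate(unique)}
    return np.array([mapping[cid] for cid in chain_ids])


def chain_end(atom_index: int, end_resname: str, chain_name: str, residue_index: int) -> str:
    """Format a PDB TER record, mirroring the _chain_end helper in the diff."""
    return (f"{'TER':<6}{atom_index:>5}      {end_resname:>3} "
            f"{chain_name:>1}{residue_index:>4}")


chain_index = chain_letters_to_index(["A", "A", "B", "B", "B"])
print(chain_index)                                 # [0 0 1 1 1]
print(chain_end(42, "GLY", PDB_CHAIN_IDS[1], 5))   # a TER record for chain B
```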
fastfold/data/data_pipeline.py (+1 -1)

@@ -249,7 +249,7 @@ def run_msa_tool(

run_msa_tool() now limits the Stockholm output for every sto-format run, not only when the caller set max_sto_sequences explicitly:

 """Runs an MSA tool, checking if output already exists first."""
-if msa_format == "sto" and max_sto_sequences is not None:
+if msa_format == "sto":
     result = msa_runner.query(fasta_path, max_sto_sequences)[0]
 else:
     result = msa_runner.query(fasta_path)
fastfold/data/templates.py (+4 -1)

@@ -866,6 +866,9 @@ def _process_single_hit(

After the template features are built (kalign_binary_path=..., _zero_center_positions=...), the template_sum_probs feature is guarded against hits that report no sum_probs:

+if hit.sum_probs is None:
+    features['template_sum_probs'] = [0]
+else:
+    features["template_sum_probs"] = [hit.sum_probs]

 # It is possible there were some errors when parsing the other chains in the
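A tiny hedged sketch of the same defensive-default pattern follows; the `Hit` dataclass is a stand-in for the real HHSearch hit object used by FastFold's template pipeline.

```python
# Sketch of the defensive default added here, with a stand-in Hit class.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Hit:
    sum_probs: Optional[float] = None


def template_sum_probs_feature(hit: Hit) -> list:
    # A hit with no sum_probs no longer breaks feature construction;
    # it simply contributes a zero score.
    return [0] if hit.sum_probs is None else [hit.sum_probs]


print(template_sum_probs_feature(Hit()))                 # [0]
print(template_sum_probs_feature(Hit(sum_probs=87.3)))   # [87.3]
```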
fastfold/distributed/comm.py (+12 -12)

A pure rename: the ambiguous chunks argument of the chunked gather path becomes chunk_size, in _chunk_gather(), in the public gather() wrapper, and in every internal use.

@@ -70,7 +70,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:

-def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
+def _chunk_gather(tensor: Tensor, dim=-1, chunk_size=1) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor

@@ -82,12 +82,12 @@ and @@ -103,12 +103,12 @@ (the dim=1 and dim=0 branches of _chunk_gather)

 world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=...)
 tensor_list = []
 for t in world_list:
-    tensor_list.extend(t.chunk(chunks, dim=...))
-chunk_tensor = tensor.chunk(chunks, dim=...)
-for i in range(chunks):
-    _chunk_list = [tensor_list[j * chunks + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
+    tensor_list.extend(t.chunk(chunk_size, dim=...))
+chunk_tensor = tensor.chunk(chunk_size, dim=...)
+for i in range(chunk_size):
+    _chunk_list = [tensor_list[j * chunk_size + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
     _chunk_tensor = chunk_tensor[i]
     dist.all_gather(list(_chunk_list), ...

@@ -176,14 +176,14 @@ class Reduce(torch.autograd.Function):

-def gather(input: Tensor, dim: int = -1, chunks: int = None) -> Tensor:
+def gather(input: Tensor, dim: int = -1, chunk_size: int = None) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Gather.apply(input, dim)
     else:
-        if chunks is None:
+        if chunk_size is None:
             input = _gather(input, dim=dim)
         else:
-            input = _chunk_gather(input, dim=dim, chunks=chunks)
+            input = _chunk_gather(input, dim=dim, chunk_size=chunk_size)
     return input
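The renamed argument only changes how the gather is scheduled, not what it returns. Below is a hedged, single-process sketch of that point: torch.distributed is replaced by a plain list of per-rank tensors, and `chunk_gather_sim` is a made-up name, not the library function.

```python
# Single-process sketch of what chunk_size controls in _chunk_gather: the
# result equals a plain concatenation of every rank's tensor in rank order,
# but the exchange is issued as chunk_size smaller gathers instead of one
# large one.  No process group is needed, so this runs anywhere.
import torch


def chunk_gather_sim(per_rank, dim=0, chunk_size=1):
    world_size = len(per_rank)
    # One output slot per (rank, chunk) pair, laid out rank-major as in the real code.
    out_slots = [[None] * chunk_size for _ in range(world_size)]
    for i in range(chunk_size):                 # one "all_gather" per chunk index
        for r in range(world_size):
            out_slots[r][i] = per_rank[r].chunk(chunk_size, dim=dim)[i]
    return torch.cat([torch.cat(slots, dim=dim) for slots in out_slots], dim=dim)


per_rank = [torch.arange(8) + 100 * r for r in range(2)]
chunked = chunk_gather_sim(per_rank, dim=0, chunk_size=4)
plain = torch.cat(per_rank, dim=0)
assert torch.equal(chunked, plain)   # chunking changes the schedule, not the result
```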
fastfold/model/fastnn/embedders_multimer.py (+24 -23)

TemplateEmbedderMultimer no longer stashes each template's pair embedding and runs the template pair stack once over the whole stack at the end; it allocates a running-sum buffer, pushes every template through the pair stack inside the per-template loop, and averages afterwards, so only one template's activation is alive at a time.

@@ -300,6 +300,7 @@ class TemplateEmbedderMultimer(nn.Module):

 template_embeds = []
 n_templ = batch["template_aatype"].shape[templ_dim]
+template_pair_embeddings = torch.zeros((z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
 for i in range(n_templ):
     idx = batch["template_aatype"].new_tensor(i)
     single_template_feats = tensor_tree_map(

@@ -336,7 +337,7 @@ and @@ -346,7 +347,23 @@

 rigid_vec = rigid[..., None].inverse().apply_to_point(points)
 unit_vector = rigid_vec.normalized()
-pair_act = self.template_pair_embedder(
+pair_embedding = self.template_pair_embedder(
     template_dgram, aatype_one_hot, z,
     ...
     unit_vector,
 )
-single_template_embeds["template_pair_embedding"] = pair_act
+if not inplace:
+    # [*, S_t, N, N, C_z]
+    template_pair_embeddings = template_pair_embeddings + self.template_pair_stack(
+        pair_embedding,
+        padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+        chunk_size=chunk_size,
+        _mask_trans=False,
+    ).squeeze(0)
+else:
+    # [*, S_t, N, N, C_z]
+    template_pair_embeddings += self.template_pair_stack.inplace(
+        [pair_embedding],
+        padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+        chunk_size=chunk_size,
+        _mask_trans=False,
+    )[0].squeeze(0)
 single_template_embeds.update(self.template_single_embedder(single_template_feats, ...

@@ -361,27 +378,11 @@

The post-loop call into the template pair stack (both the plain and the inplace variants), the torch.sum over the template dimension, and the ReLU plus linear_t applied to template_embeds["template_pair_embedding"] are removed. The accumulated buffer is finalized instead:

+template_pair_embeddings = template_pair_embeddings / n_templ
+template_pair_embeddings = torch.nn.functional.relu(template_pair_embeddings)
+template_pair_embeddings = self.linear_t(template_pair_embeddings)
+template_embeds["template_pair_embedding"] = template_pair_embeddings
 return template_embeds
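The memory trade-off behind this change is the classic running-sum accumulation. The hedged sketch below shows the pattern in isolation; `process_template` is a placeholder for the template pair stack, and the shapes are illustrative.

```python
# Sketch of the accumulation pattern this diff moves to: instead of keeping
# every processed template embedding and summing at the end, keep one running
# sum and average after the loop.
import torch


def process_template(pair_embedding: torch.Tensor) -> torch.Tensor:
    return pair_embedding * 2.0  # placeholder for self.template_pair_stack(...)


def embed_templates(templates, n_res: int, c: int = 64) -> torch.Tensor:
    acc = torch.zeros((n_res, n_res, c))     # lives for the whole loop
    for t in templates:                       # one template's activation at a time
        acc = acc + process_template(t)
    acc = acc / len(templates)                # average over templates
    return torch.relu(acc)                    # then the shared ReLU / projection


templates = [torch.randn(16, 16, 64) for _ in range(4)]
out = embed_templates(templates, n_res=16)
print(out.shape)  # torch.Size([16, 16, 64])
```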
fastfold/model/fastnn/evoformer.py (+2 -2)

@@ -147,8 +147,8 @@ class Evoformer(nn.Module):

The final gather in the monomer branch uses the renamed chunk_size argument and the block's own chunk_size instead of a hard-coded chunks=4:

 if self.is_multimer:
     m[0] = gather(m[0], dim=1)
 else:
-    m[0] = gather(m[0], dim=0, chunks=4)
-    z[0] = gather(z[0], dim=0, chunks=4)
+    m[0] = gather(m[0], dim=0, chunk_size=chunk_size)
+    z[0] = gather(z[0], dim=0, chunk_size=chunk_size)
 m[0] = m[0][:, :-padding_size, :]
 z[0] = z[0][:-padding_size, :-padding_size, :]
fastfold/model/fastnn/kernel/layer_norm.py (+18 -0)

FusedLayerNorm gains a chunked forward: very long inputs (third-to-last dimension larger than 4000) are split into pieces before each piece is handed to the existing kernel_forward, which wraps the Triton layer-norm kernel. This keeps intermediate memory bounded for long pair representations.

@@ -34,6 +34,24 @@ class FusedLayerNorm(torch.nn.Module):

 torch.nn.init.zeros_(self.bias)

+def forward(self, input):
+    if len(input.shape) >= 3 and input.shape[-3] > 4000:
+        out = torch.empty_like(input)
+        # set max chunk_size = dim / 2, to max compute efficiency
+        chunk_size = min(4000 * 4000 // input.shape[-3], (input.shape[-3] + 1) // 2)
+        if len(input.shape) == 3:
+            for i in range(input.shape[-3]):
+                out[i:i + chunk_size] = self.kernel_forward(input[i:i + chunk_size])
+        elif len(input.shape) == 4:
+            for j in range(input.shape[-4]):
+                for i in range(0, input.shape[-3], chunk_size):
+                    out[j, i:i + chunk_size] = self.kernel_forward(input[j, i:i + chunk_size])
+        else:
+            raise RuntimeError("Shape" + input.shape + "not implemented for layernorm yet!")
+        return out
+    else:
+        return self.kernel_forward(input)

 def kernel_forward(self, input):
     if _triton_available:
         return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias, self.eps)
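Because layer norm is independent per position, chunking along a leading dimension cannot change the result. A plain-PyTorch sketch of the same wrapper follows, with nn.LayerNorm standing in for the Triton kernel and a small threshold so the chunked path is actually exercised; `ChunkedLayerNorm` is an illustrative name, not the FastFold class.

```python
# Plain-PyTorch sketch of a chunked layer-norm forward.  The real code only
# chunks when the third-to-last dimension exceeds 4000.
import torch
import torch.nn as nn


class ChunkedLayerNorm(nn.Module):
    def __init__(self, dim: int, threshold: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)      # stand-in for kernel_forward
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() >= 3 and x.shape[-3] > self.threshold:
            out = torch.empty_like(x)
            chunk = max(1, (x.shape[-3] + 1) // 2)
            for i in range(0, x.shape[-3], chunk):
                out[..., i:i + chunk, :, :] = self.norm(x[..., i:i + chunk, :, :])
            return out
        return self.norm(x)


ln = ChunkedLayerNorm(32)
x = torch.randn(16, 10, 32)
# Chunked and unchunked paths agree, since normalization is per-position.
assert torch.allclose(ln(x), ln.norm(x), atol=1e-6)
```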
fastfold/model/fastnn/msa.py (+3 -3)

@@ -303,7 +303,7 @@ class ExtraMSABlock(nn.Module):

The multimer scatter of the MSA activation now also drops unused shards:

 torch.cuda.empty_cache()
-m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2)
+m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2, drop_unused=True)
 torch.cuda.empty_cache()
 z[0] = scatter(z[0], dim=1, drop_unused=True)
 torch.cuda.empty_cache()

@@ -339,9 +339,9 @@

The last-block gathers use the renamed chunk_size argument with the block's chunk_size instead of a fixed chunks=4:

 if self.last_block:
-    m[0] = gather(m[0], dim=1, chunks=4) if not self.is_multimer else gather(m[0], dim=2)
+    m[0] = gather(m[0], dim=1, chunk_size=chunk_size) if not self.is_multimer else gather(m[0], dim=2)
     torch.cuda.empty_cache()
-    z[0] = gather(z[0], dim=1, chunks=4)
+    z[0] = gather(z[0], dim=1, chunk_size=chunk_size)
     m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
     z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
fastfold/model/fastnn/ops.py (+100 -53)

A single pattern runs through this file: the chunked modules (ChunkTransition, OutProductMean, SelfAttention, ChunkMSAColumnGlobalAttention, RecyclingEmbedder, GlobalAttention) each gain an explicit unchunked fast path. Previously each forward fell back to chunk_size = para_dim when the global CHUNK_SIZE was None and still went through the slicing loop; now `if CHUNK_SIZE == None:` processes the whole tensor in one call, and the slicing loop is kept only in the else branch with chunk_size derived from CHUNK_SIZE.

@@ -91,13 +91,12 @@ ChunkTransition.forward, unchunked path runs the transition directly:

 if CHUNK_SIZE == None:
     out = self.norm(src)
     out = self.linear2(F.relu(self.linear1(out)))
 else:
     chunk_size = CHUNK_SIZE * 48
     para_dim = src.shape[1]
     out = torch.empty_like(src)
     for ax in range(0, para_dim, chunk_size):
         if DEBUG and ax > 10:
             ...

@@ -155,11 +154,14 @@ OutProductMean.forward, unchunked path computes the outer-product mean in one einsum:

 right_act_all = gather_async_opp(right_act_all, work, dim=2)
 right_act_all = M_mask * right_act_all
 if CHUNK_SIZE == None:
     out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
     out = rearrange(out, 'b i j d e -> b i j (d e)')
     out = self.o_linear(out)
     Z = out / norm
 else:
     para_dim = left_act.shape[2]
     chunk_size = CHUNK_SIZE
     for ax in range(0, para_dim, chunk_size):
         left_act_part = left_act[:, :, ax:ax + chunk_size, :]
         O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
         ...

@@ -293,11 +295,6 @@ and @@ -306,6 +303,32 @@ SelfAttention.forward, the old chunk_size = para_dim fallback is removed and an unchunked attention path is added:

 bias = gather_async_opp(*nonbatched_bias, dim=1)
 bias = rearrange(bias, 'b q k h -> b h q k')
 if CHUNK_SIZE == None:
     qkv = self.to_qkv(in_data).chunk(3, dim=-1)
     q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
     q = q * self.scaling
     logits = torch.matmul(q, k.transpose(-1, -2))
     if nonbatched_bias is not None:
         weights = fused_softmax(logits, mask, bias.unsqueeze(1))
     else:
         weights = fused_softmax(logits, mask)
     weighted_avg = torch.matmul(weights, v)
     weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
     if self.gating:
         gate_values = self.gating_linear(in_data)
         weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
     output = self.o_linear(weighted_avg)
 else:
     para_dim = in_data.shape[1]
     chunk_size = CHUNK_SIZE
     output = []
     for ax in range(0, para_dim, chunk_size):
         ...

@@ -983,16 +1006,17 @@ ChunkMSAColumnGlobalAttention.forward, unchunked path applies the column global attention to the full MSA:

 if CHUNK_SIZE is None:
     m = self.layernormM(M_raw.transpose(-2, -3))
     m = self.global_attention(m, M_mask.transpose(-1, -2))
     m = m.transpose(-2, -3)
     M_raw = M_raw + m
 else:
     chunk_size = CHUNK_SIZE
     para_dim = M_raw.shape[2]
     for i in range(0, para_dim, chunk_size):
         if DEBUG and i > 10:
             break
         m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
         m = self.layernormM(m)
         m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
         ...

@@ -1111,12 +1135,12 @@ RecyclingEmbedder, unchunked path embeds the binned distogram in one shot:

 # [*, N, N, no_bins]
 d = ((d > squared_bins) * (d < upper)).type(x.dtype)
 # [*, N, N, C_z]
 if CHUNK_SIZE == None:
     d = self.linear(d)
     z = d + self.layer_norm_z(z)
 else:
     chunk_size = CHUNK_SIZE * 48
     para_dim = d.shape[1]
     for i in range(0, para_dim, chunk_size):
         di = self.linear(d[i:i + chunk_size, :, :])
         ...

@@ -1154,10 +1178,33 @@ GlobalAttention.forward, unchunked path:

 if CHUNK_SIZE == None:
     q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)
     q = q * self.scaling
     q = self.to_q(q)
     q = q.view(q.shape[:-1] + (self.n_head, -1))
     k, v = self.to_kv(m).chunk(2, dim=-1)
     logits = torch.matmul(q, k.transpose(-1, -2))
     weights = fused_softmax(logits, mask)
     weighted_avg = torch.matmul(weights, v)
     weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
     gate_values = self.gating_linear(m)
     weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg.unsqueeze(-2))
     m = self.o_linear(weighted_avg)
 else:
     para_dim = m.shape[1]
     chunk_size = CHUNK_SIZE
     output = []
     for ax in range(0, para_dim, chunk_size):
         ...
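The control-flow pattern added throughout ops.py is easy to see on a toy module. The hedged sketch below uses a module-level CHUNK_SIZE global the way FastFold does, but the `ToyTransition` class and its layer sizes are invented for illustration, not the library's code.

```python
# Sketch of the "unchunked fast path" pattern: when CHUNK_SIZE is unset, run
# the whole tensor through the body in one call; otherwise iterate over
# slices of the sequence dimension and fill a pre-allocated output.
import torch
import torch.nn as nn
import torch.nn.functional as F

CHUNK_SIZE = None  # set to an int (e.g. 4) to enable chunking


class ToyTransition(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.linear1 = nn.Linear(d, 4 * d)
        self.linear2 = nn.Linear(4 * d, d)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        if CHUNK_SIZE is None:
            out = self.norm(src)                       # whole tensor at once
            return self.linear2(F.relu(self.linear1(out)))
        out = torch.empty_like(src)
        for ax in range(0, src.shape[1], CHUNK_SIZE):  # slice-by-slice
            part = self.norm(src[:, ax:ax + CHUNK_SIZE])
            out[:, ax:ax + CHUNK_SIZE] = self.linear2(F.relu(self.linear1(part)))
        return out


m = ToyTransition(8)
x = torch.randn(2, 10, 8)
print(m(x).shape)  # torch.Size([2, 10, 8])
```

The chunked branch trades a Python loop for a lower peak activation size; the fast path avoids the loop overhead entirely when no memory limit is requested.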
fastfold/model/fastnn/template.py (+30 -61)

TemplatePairBlock stops iterating over templates one at a time. Both the regular forward and the inplace variant now scatter the full mask once and run the triangle updates on the whole stacked template tensor, and the CPU-offload dance (moving z[0] to the CPU for small chunk sizes and shuttling single templates back to the mask's device) is dropped. TemplatePairStack's final layer norm is likewise simplified from a doubly chunked loop to a per-template loop.

@@ -241,25 +241,18 @@ TemplatePairBlock.forward:

 mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
-for i in range(z.shape[0]):
-    single = z[i].unsqueeze(-4)
-    single_mask = mask[i].unsqueeze(-3)
-    single_mask_row = scatter(single_mask, dim=1)
-    single_mask_col = scatter(single_mask, dim=2)
-    single = self.TriangleAttentionStartingNode(single, single_mask_row)
-    ... (attention ending node, both triangle multiplications, pair transition,
-         with row_to_col / col_to_row in between)
-    z[i] = single
+single_mask_row = scatter(mask, dim=1)
+single_mask_col = scatter(mask, dim=2)
+z = self.TriangleAttentionStartingNode(z, single_mask_row)
+z = row_to_col(z)
+z = self.TriangleAttentionEndingNode(z, single_mask_col)
+z = col_to_row(z)
+z = self.TriangleMultiplicationOutgoing(z, single_mask_row)
+z = row_to_col(z)
+z = self.TriangleMultiplicationIncoming(z, single_mask_col)
+z = self.PairTransition(z)
+z = col_to_row(z)
 if self.last_block:

@@ -275,8 +268,6 @@ TemplatePairBlock.inplace, the CPU offload for 1 <= chunk_size <= 4 is removed:

-if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-    z[0] = z[0].cpu()
 dap_size = gpc.get_world_size(ParallelMode.TENSOR)
 seq_length = mask.size(-1)

@@ -290,32 +281,24 @@ TemplatePairBlock.inplace, the per-template loop becomes batched in-place updates and the final gather passes chunk_size through:

 mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
+single_mask_row = scatter(mask, dim=1, drop_unused=True)
+single_mask_col = scatter(mask, dim=2, drop_unused=True)
+torch.cuda.empty_cache()
+z = self.TriangleAttentionStartingNode.inplace(z, single_mask_row)
+z[0] = row_to_col(z[0])
+z = self.TriangleAttentionEndingNode.inplace(z, single_mask_col)
+z[0] = col_to_row(z[0])
+z[0] = self.TriangleMultiplicationOutgoing(z[0], single_mask_row)
+z[0] = row_to_col(z[0])
+z[0] = self.TriangleMultiplicationIncoming(z[0], single_mask_col)
+z = self.PairTransition.inplace(z)
+z[0] = col_to_row(z[0])
 if self.last_block:
-    if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-        z[0] = z[0].to(mask.device)
-    z[0] = gather(z[0], dim=1)
+    z[0] = gather(z[0], dim=1, chunk_size=chunk_size)
     z[0] = z[0][:, :-padding_size, :-padding_size, :]
 return z

@@ -411,15 +394,8 @@ TemplatePairStack, the final layer norm no longer tiles the sequence dimension into 4000x4000 blocks:

-if chunk_size is None:
-    chunk_size = t.shape[0]
-for i in range(0, t.shape[0], chunk_size):
-    if t.shape[1] > 4000:
-        chunk_new = int(4000 * 4000 / t.shape[1])
-        for j in range(0, t.shape[1], chunk_new):
-            t[i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[i:i + chunk_size, j:j + chunk_new])
-    else:
-        t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
+for i in range(0, t.shape[0]):
+    t[i] = self.layer_norm(t[i])
 return t

@@ -456,13 +432,6 @@ TemplatePairStack.inplace, the same simplification, keeping the device round trip per template:

+for i in range(0, t[0].shape[0]):
+    t[0][i] = self.layer_norm(t[0][i].to(mask.device)).to(t[0].device)
 return t
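The simplification works because the triangle updates act on each template independently, so one batched call over the leading template dimension gives the same result as the removed per-template loop. A hedged sketch of that equivalence follows; `Update` is a stand-in for the triangle attention / multiplication / transition sequence, not FastFold code.

```python
# Sketch: looping over the leading template dimension equals one batched call
# when the update is independent per template.
import torch
import torch.nn as nn


class Update(nn.Module):
    def __init__(self, c: int):
        super().__init__()
        self.linear = nn.Linear(c, c)

    def forward(self, z: torch.Tensor) -> torch.Tensor:  # [..., N, N, C]
        return z + self.linear(z)


update = Update(8)
z = torch.randn(3, 5, 5, 8)               # [n_templates, N, N, C]

with torch.no_grad():
    looped = z.clone()
    for i in range(looped.shape[0]):       # old style: one template at a time
        looped[i] = update(looped[i].unsqueeze(0)).squeeze(0)
    batched = update(z)                    # new style: all templates at once

assert torch.allclose(looped, batched, atol=1e-6)
```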
fastfold/model/nn/dropout.py (+1 -1)

@@ -56,7 +56,7 @@ class Dropout(nn.Module):

The dropout mask is now applied out of place, so autograd never sees an in-place modification of the input:

 shape[bd] = 1
 mask = x.new_ones(shape)
 mask = self.dropout(mask)
-x *= mask
+x = x * mask
 return x
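Why the one-character change matters: an in-place multiply on a tensor autograd still needs can fail outright, most visibly on a leaf tensor that requires grad, while the out-of-place form is always safe. A small self-contained demonstration:

```python
# In-place vs out-of-place multiplication under autograd.
import torch

mask = torch.ones(3)

# In-place on a leaf tensor that requires grad raises immediately.
x = torch.randn(3, requires_grad=True)
try:
    x *= mask
except RuntimeError as e:
    print("in-place failed:", type(e).__name__)

# Out-of-place works and gradients flow through the mask as expected.
x = torch.randn(3, requires_grad=True)
y = (x * mask).sum()
y.backward()
print(x.grad)   # tensor([1., 1., 1.])
```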
fastfold/model/nn/evoformer.py (+0 -5)

@@ -264,11 +264,6 @@ class EvoformerBlock(nn.Module):

The construction of the outer-product-mean module is removed from the block's __init__:

     eps=eps,
 )
-self.outer_product_mean = OuterProductMean(
-    c_m,
-    c_z,
-    c_hidden_opm,
-)
 self.is_multimer = is_multimer

 def forward(self, ...
inference.py (+2 -0)

@@ -227,9 +227,11 @@ def inference_multimer_model(args):

The multimer inference path now falls back to a randomly drawn seed when none is supplied:

 output_dir_base = args.output_dir
 random_seed = args.data_random_seed
+if random_seed is None:
+    random_seed = random.randrange(sys.maxsize)
 # seed_torch(seed=1029)
 feature_processor = feature_pipeline.FeaturePipeline(config.data ...
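A tiny hedged sketch of that fallback in isolation; `resolve_seed` is an illustrative name rather than a FastFold function.

```python
# Seed fallback: honour the user-supplied value, otherwise draw one so the
# run can still be reproduced once the chosen seed is recorded.
import random
import sys
from typing import Optional


def resolve_seed(data_random_seed: Optional[int]) -> int:
    if data_random_seed is None:
        return random.randrange(sys.maxsize)
    return data_random_seed


print(resolve_seed(None))   # some large random integer
print(resolve_seed(1029))   # 1029
```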
setup.py (+2 -2)

@@ -15,8 +15,8 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):

The bundled-toolkit version check reads the HIP version rather than the CUDA version from the PyTorch binary:

-torch_binary_major = torch.version.cuda.split(".")[0]
-torch_binary_minor = torch.version.cuda.split(".")[1]
+torch_binary_major = torch.version.hip.split(".")[0]
+torch_binary_minor = torch.version.hip.split(".")[1]
 print("\nCompiling cuda extensions with")
tests/test_fastnn/test_msa_att_col.py (+1 -1)

@@ -73,4 +73,4 @@ def _test_msa_att_col(rank, world_size, chunk_size, get_openfold_module_and_data

The numerical tolerance of the MSA column attention test is loosened from 5e-5 to 1e-4:

 m_fast = m_fast[:, :-padding_size, :]
 error = torch.max(torch.abs(m_out.cuda() - m_fast))
-assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
+assert error < 1e-4, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
tests/test_fastnn/test_template_embedder.py (+1 -1)

@@ -46,7 +46,7 @@ def get_openfold_module_and_data():

The chunk_size parametrization drops 32 in favour of 4; the in-file comment notes the smaller value is meant to exercise the offload path:

 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('chunk_size', [None, 32])
+@pytest.mark.parametrize('chunk_size', [None, 4])  # should set 4 to test offload
 @pytest.mark.parametrize('inplace', [False, True])
 def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
     run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size, ...