Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e71c1b14
Commit
e71c1b14
authored
Jan 12, 2024
by
Jennifer
Browse files
initial compatibility changes for upgrading multimer
parent
9a07b7f9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
22 additions
and
21 deletions
+22
-21
.gitignore
.gitignore
+1
-1
environment.yml
environment.yml
+10
-9
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+6
-6
openfold/data/templates.py
openfold/data/templates.py
+2
-2
openfold/model/primitives.py
openfold/model/primitives.py
+2
-2
setup.py
setup.py
+1
-1
No files found.
.gitignore
View file @
e71c1b14
...
...
@@ -9,4 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass
environment.yml
View file @
e71c1b14
...
...
@@ -3,6 +3,7 @@ channels:
-
conda-forge
-
bioconda
-
pytorch
-
nvidia
dependencies
:
-
python=3.9
-
libgcc=7.2
...
...
@@ -10,17 +11,16 @@ dependencies:
-
pip
-
openmm=7.7
-
pdbfixer
-
cudatoolkit==11.3.*
-
pytorch-lightning==1.5.10
-
pytorch-lightning
-
biopython==1.79
-
numpy
==1.21
-
pandas
==2.0
-
numpy
-
pandas
-
PyYAML==5.4.1
-
requests
-
scipy
==1.7
-
scipy
-
tqdm==4.62.2
-
typing-extensions
==3.10
-
wandb
==0.12.21
-
typing-extensions
-
wandb
-
modelcif==0.7
-
awscli
-
ml-collections
...
...
@@ -29,9 +29,10 @@ dependencies:
-
bioconda::hmmer==3.3.2
-
bioconda::hhsuite==3.3.0
-
bioconda::kalign2==2.04
-
pytorch::pytorch=1.12.*
-
pytorch::pytorch=2.1
-
pytorch::pytorch-cuda=12.1
-
pip
:
-
deepspeed==0.12.4
-
dm-tree==0.1.6
-
git+https://github.com/NVIDIA/dllogger.git
-
git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
-
flash-attn
openfold/data/data_pipeline.py
View file @
e71c1b14
...
...
@@ -110,12 +110,12 @@ def make_sequence_features(
)
features
[
"between_segment_residues"
]
=
np
.
zeros
((
num_res
,),
dtype
=
np
.
int32
)
features
[
"domain_name"
]
=
np
.
array
(
[
description
.
encode
(
"utf-8"
)],
dtype
=
np
.
object
_
[
description
.
encode
(
"utf-8"
)],
dtype
=
object
)
features
[
"residue_index"
]
=
np
.
array
(
range
(
num_res
),
dtype
=
np
.
int32
)
features
[
"seq_length"
]
=
np
.
array
([
num_res
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
"sequence"
]
=
np
.
array
(
[
sequence
.
encode
(
"utf-8"
)],
dtype
=
np
.
object
_
[
sequence
.
encode
(
"utf-8"
)],
dtype
=
object
)
return
features
...
...
@@ -148,7 +148,7 @@ def make_mmcif_features(
)
mmcif_feats
[
"release_date"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"release_date"
].
encode
(
"utf-8"
)],
dtype
=
np
.
object
_
[
mmcif_object
.
header
[
"release_date"
].
encode
(
"utf-8"
)],
dtype
=
object
)
mmcif_feats
[
"is_distillation"
]
=
np
.
array
(
0.
,
dtype
=
np
.
float32
)
...
...
@@ -247,7 +247,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features
[
"num_alignments"
]
=
np
.
array
(
[
num_alignments
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
"msa_species_identifiers"
]
=
np
.
array
(
species_ids
,
dtype
=
np
.
object
_
)
features
[
"msa_species_identifiers"
]
=
np
.
array
(
species_ids
,
dtype
=
object
)
return
features
...
...
@@ -593,7 +593,7 @@ def convert_monomer_features(
)
->
FeatureDict
:
"""Reshapes and modifies monomer features for multimer models."""
converted
=
{}
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
np
.
object
_
)
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
object
)
unnecessary_leading_dim_feats
=
{
'sequence'
,
'domain_name'
,
'num_alignments'
,
'seq_length'
}
...
...
@@ -1290,7 +1290,7 @@ class DataPipelineMultimer:
)
mmcif_feats
[
"release_date"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"release_date"
].
encode
(
"utf-8"
)],
dtype
=
np
.
object
_
[
mmcif_object
.
header
[
"release_date"
].
encode
(
"utf-8"
)],
dtype
=
object
)
mmcif_feats
[
"is_distillation"
]
=
np
.
array
(
0.
,
dtype
=
np
.
float32
)
...
...
openfold/data/templates.py
View file @
e71c1b14
...
...
@@ -83,8 +83,8 @@ TEMPLATE_FEATURES = {
"template_aatype"
:
np
.
int64
,
"template_all_atom_mask"
:
np
.
float32
,
"template_all_atom_positions"
:
np
.
float32
,
"template_domain_names"
:
np
.
object
,
"template_sequence"
:
np
.
object
,
"template_domain_names"
:
object
,
"template_sequence"
:
object
,
"template_sum_probs"
:
np
.
float32
,
}
...
...
openfold/model/primitives.py
View file @
e71c1b14
...
...
@@ -28,7 +28,7 @@ if ds4s_is_installed:
fa_is_installed
=
importlib
.
util
.
find_spec
(
"flash_attn"
)
is
not
None
if
fa_is_installed
:
from
flash_attn.bert_padding
import
unpad_input
from
flash_attn.flash_attn_interface
import
flash_attn_
unpadded
_kvpacked_func
from
flash_attn.flash_attn_interface
import
flash_attn_
varlen
_kvpacked_func
import
torch
import
torch.nn
as
nn
...
...
@@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask):
kv_unpad
,
_
,
kv_cu_seqlens
,
kv_max_s
=
unpad_input
(
kv
,
kv_mask
)
kv_unpad
=
kv_unpad
.
reshape
(
-
1
,
*
kv_shape
[
-
3
:])
out
=
flash_attn_
unpadded
_kvpacked_func
(
out
=
flash_attn_
varlen
_kvpacked_func
(
q
,
kv_unpad
,
q_cu_seqlens
,
...
...
setup.py
View file @
e71c1b14
...
...
@@ -29,7 +29,7 @@ version_dependent_macros = [
]
extra_cuda_flags
=
[
'-std=c++1
4
'
,
'-std=c++1
7
'
,
'-maxrregcount=50'
,
'-U__CUDA_NO_HALF_OPERATORS__'
,
'-U__CUDA_NO_HALF_CONVERSIONS__'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment