Commit e71c1b14 authored by Jennifer
Browse files

initial compatibility changes for upgrading multimer

parent 9a07b7f9
......@@ -9,4 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass
......@@ -3,6 +3,7 @@ channels:
- conda-forge
- bioconda
- pytorch
- nvidia
dependencies:
- python=3.9
- libgcc=7.2
......@@ -10,17 +11,16 @@ dependencies:
- pip
- openmm=7.7
- pdbfixer
- cudatoolkit==11.3.*
- pytorch-lightning==1.5.10
- pytorch-lightning
- biopython==1.79
- numpy==1.21
- pandas==2.0
- numpy
- pandas
- PyYAML==5.4.1
- requests
- scipy==1.7
- scipy
- tqdm==4.62.2
- typing-extensions==3.10
- wandb==0.12.21
- typing-extensions
- wandb
- modelcif==0.7
- awscli
- ml-collections
......@@ -29,9 +29,10 @@ dependencies:
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pytorch::pytorch=2.1
- pytorch::pytorch-cuda=12.1
- pip:
- deepspeed==0.12.4
- dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
- flash-attn
......@@ -110,12 +110,12 @@ def make_sequence_features(
)
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
[description.encode("utf-8")], dtype=object
)
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
[sequence.encode("utf-8")], dtype=object
)
return features
......@@ -148,7 +148,7 @@ def make_mmcif_features(
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
[mmcif_object.header["release_date"].encode("utf-8")], dtype=object
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......@@ -247,7 +247,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
return features
......@@ -593,7 +593,7 @@ def convert_monomer_features(
) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'
}
......@@ -1290,7 +1290,7 @@ class DataPipelineMultimer:
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
[mmcif_object.header["release_date"].encode("utf-8")], dtype=object
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......
......@@ -83,8 +83,8 @@ TEMPLATE_FEATURES = {
"template_aatype": np.int64,
"template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32,
"template_domain_names": np.object,
"template_sequence": np.object,
"template_domain_names": object,
"template_sequence": object,
"template_sum_probs": np.float32,
}
......
......@@ -28,7 +28,7 @@ if ds4s_is_installed:
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed:
from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
import torch
import torch.nn as nn
......@@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask):
kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
out = flash_attn_unpadded_kvpacked_func(
out = flash_attn_varlen_kvpacked_func(
q,
kv_unpad,
q_cu_seqlens,
......
......@@ -29,7 +29,7 @@ version_dependent_macros = [
]
extra_cuda_flags = [
'-std=c++14',
'-std=c++17',
'-maxrregcount=50',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment