"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "afa6adf1c133ac1b15b1a0ec0d9b5f26f59bb111"
Commit e71c1b14 authored by Jennifer's avatar Jennifer
Browse files

initial compatibility changes for upgrading multimer

parent 9a07b7f9
...@@ -9,4 +9,4 @@ dist ...@@ -9,4 +9,4 @@ dist
data data
openfold/resources/ openfold/resources/
tests/test_data/ tests/test_data/
cutlass
...@@ -3,6 +3,7 @@ channels: ...@@ -3,6 +3,7 @@ channels:
- conda-forge - conda-forge
- bioconda - bioconda
- pytorch - pytorch
- nvidia
dependencies: dependencies:
- python=3.9 - python=3.9
- libgcc=7.2 - libgcc=7.2
...@@ -10,17 +11,16 @@ dependencies: ...@@ -10,17 +11,16 @@ dependencies:
- pip - pip
- openmm=7.7 - openmm=7.7
- pdbfixer - pdbfixer
- cudatoolkit==11.3.* - pytorch-lightning
- pytorch-lightning==1.5.10
- biopython==1.79 - biopython==1.79
- numpy==1.21 - numpy
- pandas==2.0 - pandas
- PyYAML==5.4.1 - PyYAML==5.4.1
- requests - requests
- scipy==1.7 - scipy
- tqdm==4.62.2 - tqdm==4.62.2
- typing-extensions==3.10 - typing-extensions
- wandb==0.12.21 - wandb
- modelcif==0.7 - modelcif==0.7
- awscli - awscli
- ml-collections - ml-collections
...@@ -29,9 +29,10 @@ dependencies: ...@@ -29,9 +29,10 @@ dependencies:
- bioconda::hmmer==3.3.2 - bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0 - bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04 - bioconda::kalign2==2.04
- pytorch::pytorch=1.12.* - pytorch::pytorch=2.1
- pytorch::pytorch-cuda=12.1
- pip: - pip:
- deepspeed==0.12.4 - deepspeed==0.12.4
- dm-tree==0.1.6 - dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git - git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 - flash-attn
...@@ -110,12 +110,12 @@ def make_sequence_features( ...@@ -110,12 +110,12 @@ def make_sequence_features(
) )
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32) features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array( features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_ [description.encode("utf-8")], dtype=object
) )
features["residue_index"] = np.array(range(num_res), dtype=np.int32) features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32) features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array( features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_ [sequence.encode("utf-8")], dtype=object
) )
return features return features
...@@ -148,7 +148,7 @@ def make_mmcif_features( ...@@ -148,7 +148,7 @@ def make_mmcif_features(
) )
mmcif_feats["release_date"] = np.array( mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
) )
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
...@@ -247,7 +247,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: ...@@ -247,7 +247,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features["num_alignments"] = np.array( features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32 [num_alignments] * num_res, dtype=np.int32
) )
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_) features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
return features return features
...@@ -593,7 +593,7 @@ def convert_monomer_features( ...@@ -593,7 +593,7 @@ def convert_monomer_features(
) -> FeatureDict: ) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models.""" """Reshapes and modifies monomer features for multimer models."""
converted = {} converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
unnecessary_leading_dim_feats = { unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length' 'sequence', 'domain_name', 'num_alignments', 'seq_length'
} }
...@@ -1290,7 +1290,7 @@ class DataPipelineMultimer: ...@@ -1290,7 +1290,7 @@ class DataPipelineMultimer:
) )
mmcif_feats["release_date"] = np.array( mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
) )
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......
...@@ -83,8 +83,8 @@ TEMPLATE_FEATURES = { ...@@ -83,8 +83,8 @@ TEMPLATE_FEATURES = {
"template_aatype": np.int64, "template_aatype": np.int64,
"template_all_atom_mask": np.float32, "template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32, "template_all_atom_positions": np.float32,
"template_domain_names": np.object, "template_domain_names": object,
"template_sequence": np.object, "template_sequence": object,
"template_sum_probs": np.float32, "template_sum_probs": np.float32,
} }
......
...@@ -28,7 +28,7 @@ if ds4s_is_installed: ...@@ -28,7 +28,7 @@ if ds4s_is_installed:
fa_is_installed = importlib.util.find_spec("flash_attn") is not None fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed: if fa_is_installed:
from flash_attn.bert_padding import unpad_input from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask): ...@@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask):
kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
out = flash_attn_unpadded_kvpacked_func( out = flash_attn_varlen_kvpacked_func(
q, q,
kv_unpad, kv_unpad,
q_cu_seqlens, q_cu_seqlens,
......
...@@ -29,7 +29,7 @@ version_dependent_macros = [ ...@@ -29,7 +29,7 @@ version_dependent_macros = [
] ]
extra_cuda_flags = [ extra_cuda_flags = [
'-std=c++14', '-std=c++17',
'-maxrregcount=50', '-maxrregcount=50',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment