Commit 89a6875f authored by Rich Evans's avatar Rich Evans Committed by Copybara-Service
Browse files

Remove prokaryotic pairing from AFOS.

PiperOrigin-RevId: 429521382
Change-Id: Ib9bfce50834bb129af0dbd3508cd0b92856a9bb1
parent b787ca29
...@@ -201,6 +201,11 @@ change the following: ...@@ -201,6 +201,11 @@ change the following:
`alphafold/model/config.py`. `alphafold/model/config.py`.
* Setting the `data_dir` flag is now needed when using `run_docker.py`. * Setting the `data_dir` flag is now needed when using `run_docker.py`.
### API changes between v2.1.0 and v2.2.0
The `--is_prokaryote_list` flag has been removed along with the `is_prokaryote`
argument in `run_alphafold.predict_structure()`; eukaryotes and prokaryotes are
now paired in the same way.
## Running AlphaFold ## Running AlphaFold
...@@ -303,16 +308,12 @@ All steps are the same as when running the monomer system, but you will have to ...@@ -303,16 +308,12 @@ All steps are the same as when running the monomer system, but you will have to
* provide an input fasta with multiple sequences, * provide an input fasta with multiple sequences,
* set `--model_preset=multimer`, * set `--model_preset=multimer`,
* optionally set the `--is_prokaryote_list` flag with booleans that determine
whether all input sequences in the given fasta file are prokaryotic. If that
is not the case or the origin is unknown, set to `false` for that fasta.
An example that folds a protein complex `multimer.fasta` that is prokaryotic: An example that folds a protein complex `multimer.fasta`:
```bash ```bash
python3 docker/run_docker.py \ python3 docker/run_docker.py \
--fasta_paths=multimer.fasta \ --fasta_paths=multimer.fasta \
--is_prokaryote_list=true \
--max_template_date=2020-05-14 \ --max_template_date=2020-05-14 \
--model_preset=multimer \ --model_preset=multimer \
--data_dir=$DOWNLOAD_DIR --data_dir=$DOWNLOAD_DIR
...@@ -348,7 +349,7 @@ python3 docker/run_docker.py \ ...@@ -348,7 +349,7 @@ python3 docker/run_docker.py \
#### Folding a homomer #### Folding a homomer
Say we have a homomer from a prokaryote with 3 copies of the same sequence Say we have a homomer with 3 copies of the same sequence
`<SEQUENCE>`. The input fasta should be: `<SEQUENCE>`. The input fasta should be:
```fasta ```fasta
...@@ -365,7 +366,6 @@ Then run the following command: ...@@ -365,7 +366,6 @@ Then run the following command:
```bash ```bash
python3 docker/run_docker.py \ python3 docker/run_docker.py \
--fasta_paths=homomer.fasta \ --fasta_paths=homomer.fasta \
--is_prokaryote_list=true \
--max_template_date=2021-11-01 \ --max_template_date=2021-11-01 \
--model_preset=multimer \ --model_preset=multimer \
--data_dir=$DOWNLOAD_DIR --data_dir=$DOWNLOAD_DIR
...@@ -373,7 +373,7 @@ python3 docker/run_docker.py \ ...@@ -373,7 +373,7 @@ python3 docker/run_docker.py \
#### Folding a heteromer #### Folding a heteromer
Say we have a heteromer A2B3 of unknown origin, i.e. with 2 copies of Say we have an A2B3 heteromer, i.e. with 2 copies of
`<SEQUENCE A>` and 3 copies of `<SEQUENCE B>`. The input fasta should be: `<SEQUENCE A>` and 3 copies of `<SEQUENCE B>`. The input fasta should be:
```fasta ```fasta
...@@ -394,7 +394,6 @@ Then run the following command: ...@@ -394,7 +394,6 @@ Then run the following command:
```bash ```bash
python3 docker/run_docker.py \ python3 docker/run_docker.py \
--fasta_paths=heteromer.fasta \ --fasta_paths=heteromer.fasta \
--is_prokaryote_list=false \
--max_template_date=2021-11-01 \ --max_template_date=2021-11-01 \
--model_preset=multimer \ --model_preset=multimer \
--data_dir=$DOWNLOAD_DIR --data_dir=$DOWNLOAD_DIR
...@@ -416,15 +415,13 @@ python3 docker/run_docker.py \ ...@@ -416,15 +415,13 @@ python3 docker/run_docker.py \
#### Folding multiple multimers one after another #### Folding multiple multimers one after another
Say we have two multimers, `multimer1.fasta` and `multimer2.fasta`. Both are Say we have two multimers, `multimer1.fasta` and `multimer2.fasta`.
from a prokaryotic organism.
We can fold both sequentially by using the following command: We can fold both sequentially by using the following command:
```bash ```bash
python3 docker/run_docker.py \ python3 docker/run_docker.py \
--fasta_paths=multimer1.fasta,multimer2.fasta \ --fasta_paths=multimer1.fasta,multimer2.fasta \
--is_prokaryote_list=true,true \
--max_template_date=2021-11-01 \ --max_template_date=2021-11-01 \
--model_preset=multimer \ --model_preset=multimer \
--data_dir=$DOWNLOAD_DIR --data_dir=$DOWNLOAD_DIR
......
...@@ -46,14 +46,12 @@ def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool: ...@@ -46,14 +46,12 @@ def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
def pair_and_merge( def pair_and_merge(
all_chain_features: MutableMapping[str, pipeline.FeatureDict], all_chain_features: MutableMapping[str, pipeline.FeatureDict]
is_prokaryote: bool) -> pipeline.FeatureDict: ) -> pipeline.FeatureDict:
"""Runs processing on features to augment, pair and merge. """Runs processing on features to augment, pair and merge.
Args: Args:
all_chain_features: A MutableMap of dictionaries of features for each chain. all_chain_features: A MutableMap of dictionaries of features for each chain.
is_prokaryote: Whether the target complex is from a prokaryotic or
eukaryotic organism.
Returns: Returns:
A dictionary of features. A dictionary of features.
...@@ -67,7 +65,7 @@ def pair_and_merge( ...@@ -67,7 +65,7 @@ def pair_and_merge(
if pair_msa_sequences: if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features( np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list, prokaryotic=is_prokaryote) chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list) np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains( np_chains_list = crop_chains(
np_chains_list, np_chains_list,
......
...@@ -48,12 +48,11 @@ _UNIPROT_PATTERN = re.compile( ...@@ -48,12 +48,11 @@ _UNIPROT_PATTERN = re.compile(
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Identifiers: class Identifiers:
uniprot_accession_id: str = ''
species_id: str = '' species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier. """Gets species from an msa sequence identifier.
The sequence identifier has the format specified by The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
...@@ -63,13 +62,12 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: ...@@ -63,13 +62,12 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
msa_sequence_identifier: a sequence identifier. msa_sequence_identifier: a sequence identifier.
Returns: Returns:
An `Identifiers` instance with a uniprot_accession_id and species_id. These An `Identifiers` instance with species_id. These
can be empty in the case where no identifier was found. can be empty in the case where no identifier was found.
""" """
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches: if matches:
return Identifiers( return Identifiers(
uniprot_accession_id=matches.group('AccessionIdentifier'),
species_id=matches.group('SpeciesIdentifier')) species_id=matches.group('SpeciesIdentifier'))
return Identifiers() return Identifiers()
......
...@@ -25,12 +25,6 @@ import numpy as np ...@@ -25,12 +25,6 @@ import numpy as np
import pandas as pd import pandas as pd
import scipy.linalg import scipy.linalg
ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
ALPHANUM_ACCESSION_ID_MAP = {
chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
} # A-Z,0-9
NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)} # 0-9
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-') MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5 SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9 SEQUENCE_SIMILARITY_CUTOFF = 0.9
...@@ -58,15 +52,11 @@ CHAIN_FEATURES = ('num_alignments', 'seq_length') ...@@ -58,15 +52,11 @@ CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features( def create_paired_features(
chains: Iterable[pipeline.FeatureDict], chains: Iterable[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
prokaryotic: bool,
) -> List[pipeline.FeatureDict]:
"""Returns the original chains with paired NUM_SEQ features. """Returns the original chains with paired NUM_SEQ features.
Args: Args:
chains: A list of feature dictionaries for each chain. chains: A list of feature dictionaries for each chain.
prokaryotic: Whether the target complex is from a prokaryotic organism.
Used to determine the distance metric for pairing.
Returns: Returns:
A list of feature dictionaries with sequence features including only A list of feature dictionaries with sequence features including only
...@@ -79,8 +69,7 @@ def create_paired_features( ...@@ -79,8 +69,7 @@ def create_paired_features(
return chains return chains
else: else:
updated_chains = [] updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences( paired_chains_to_paired_row_indices = pair_sequences(chains)
chains, prokaryotic)
paired_rows = reorder_paired_rows( paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices) paired_chains_to_paired_row_indices)
...@@ -115,8 +104,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: ...@@ -115,8 +104,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
num_res = feature.shape[1] num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype) feature.dtype)
elif feature_name in ('msa_uniprot_accession_identifiers_all_seq', elif feature_name == 'msa_species_identifiers_all_seq':
'msa_species_identifiers_all_seq'):
padding = [b''] padding = [b'']
else: else:
return feature return feature
...@@ -134,11 +122,9 @@ def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame: ...@@ -134,11 +122,9 @@ def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
msa_df = pd.DataFrame({ msa_df = pd.DataFrame({
'msa_species_identifiers': 'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'], chain_features['msa_species_identifiers_all_seq'],
'msa_uniprot_accession_identifiers':
chain_features['msa_uniprot_accession_identifiers_all_seq'],
'msa_row': 'msa_row':
np.arange(len( np.arange(len(
chain_features['msa_uniprot_accession_identifiers_all_seq'])), chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity, 'msa_similarity': per_seq_similarity,
'gap': per_seq_gap 'gap': per_seq_gap
}) })
...@@ -153,139 +139,6 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: ...@@ -153,139 +139,6 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
return species_lookup return species_lookup
@functools.lru_cache(maxsize=65536)
def encode_accession(accession_id: str) -> int:
"""Map accession codes to the serial order in which they were assigned."""
alpha = ALPHA_ACCESSION_ID_MAP # A-Z
alphanum = ALPHANUM_ACCESSION_ID_MAP # A-Z,0-9
num = NUM_ACCESSION_ID_MAP # 0-9
coding = 0
# This is based on the uniprot accession id format
# https://www.uniprot.org/help/accession_numbers
if accession_id[0] in {'O', 'P', 'Q'}:
bases = (alpha, num, alphanum, alphanum, alphanum, num)
elif len(accession_id) == 6:
bases = (alpha, num, alpha, alphanum, alphanum, num)
elif len(accession_id) == 10:
bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum,
alphanum, num)
product = 1
for place, base in zip(reversed(accession_id), reversed(bases)):
coding += base[place] * product
product *= len(base)
return coding
def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode()))
def _find_all_accession_matches(accession_id_lists: List[List[bytes]],
diff_cutoff: int = 20
) -> List[List[Any]]:
"""Finds accession id matches across the chains based on their difference."""
all_accession_tuples = []
current_tuple = []
tokens_used_in_answer = set()
def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple))
def _all_tokens_not_used_before() -> bool:
return all((s not in tokens_used_in_answer for s in current_tuple))
def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
if level == len(accession_id_lists) - 1:
if _all_tokens_not_used_before():
all_accession_tuples.append(list(current_tuple))
for s in current_tuple:
tokens_used_in_answer.add(s)
return
if level == -1:
new_list = accession_id_lists[level+1]
else:
new_list = [(_calc_id_diff(accession_id, s), s) for
s in accession_id_lists[level+1]]
new_list = sorted(new_list)
new_list = [s for d, s in new_list]
for s in new_list:
if (_matches_all_in_current_tuple(s, diff_cutoff) and
s not in tokens_used_in_answer):
current_tuple.append(s)
dfs(level + 1, s)
current_tuple.pop()
dfs(-1, '')
return all_accession_tuples
def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series:
matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id]
return matched_df.iloc[0]
def _match_rows_by_genetic_distance(
this_species_msa_dfs: List[pd.DataFrame],
cutoff: int = 20) -> List[List[int]]:
"""Finds MSA sequence pairings across chains within a genetic distance cutoff.
The genetic distance between two sequences is approximated by taking the
difference in their UniProt accession ids.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species. If species is missing for a chain, the
dataframe is set to None.
cutoff: the genetic distance cutoff.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
num_examples = len(this_species_msa_dfs) # N
accession_id_lists = [] # M
match_index_to_chain_index = {}
for chain_index, species_df in enumerate(this_species_msa_dfs):
if species_df is not None:
accession_id_lists.append(
list(species_df.msa_uniprot_accession_identifiers.values))
# Keep track of which of the this_species_msa_dfs are not None.
match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index
all_accession_id_matches = _find_all_accession_matches(
accession_id_lists, cutoff) # [k, M]
all_paired_msa_rows = [] # [k, N]
for accession_id_match in all_accession_id_matches:
paired_msa_rows = []
for match_index, accession_id in enumerate(accession_id_match):
# Map back to chain index.
chain_index = match_index_to_chain_index[match_index]
seq_series = _accession_row(
this_species_msa_dfs[chain_index], accession_id)
if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or
seq_series.gap > SEQUENCE_GAP_CUTOFF):
continue
else:
paired_msa_rows.append(seq_series.msa_row)
# If a sequence is skipped based on sequence similarity to the respective
  # target sequence or a gap cutoff, the lengths of accession_id_match and
# paired_msa_rows will be different. Skip this match.
if len(paired_msa_rows) == len(accession_id_match):
paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
matched_chain_indices = list(match_index_to_chain_index.values())
paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))
return all_paired_msa_rows
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame] def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]: ) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity. """Finds MSA sequence pairings across chains based on sequence similarity.
...@@ -322,8 +175,8 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame] ...@@ -322,8 +175,8 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
return all_paired_msa_rows return all_paired_msa_rows
def pair_sequences(examples: List[pipeline.FeatureDict], def pair_sequences(examples: List[pipeline.FeatureDict]
prokaryotic: bool) -> Dict[int, np.ndarray]: ) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains.""" """Returns indices for paired MSA sequences across chains."""
num_examples = len(examples) num_examples = len(examples)
...@@ -365,22 +218,6 @@ def pair_sequences(examples: List[pipeline.FeatureDict], ...@@ -365,22 +218,6 @@ def pair_sequences(examples: List[pipeline.FeatureDict],
isinstance(species_df, pd.DataFrame)]) > 600): isinstance(species_df, pd.DataFrame)]) > 600):
continue continue
# In prokaryotes (and some eukaryotes), interacting genes are often
# co-located on the chromosome into operons. Because of that we can assume
  # that if two proteins' intergenic distance is less than a threshold, the
  # two proteins will form an interacting pair.
# In most eukaryotes, a single protein's MSA can contain many paralogs.
# Two genes may interact even if they are not close by genomic distance.
# In case of eukaryotes, some methods pair MSA sequences using sequence
# similarity method.
# See Jinbo Xu's work:
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
if prokaryotic:
paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs)
if not paired_msa_rows:
continue
else:
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs) paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows) all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
......
...@@ -57,7 +57,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: ...@@ -57,7 +57,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
int_msa = [] int_msa = []
deletion_matrix = [] deletion_matrix = []
uniprot_accession_ids = []
species_ids = [] species_ids = []
seen_sequences = set() seen_sequences = set()
for msa_index, msa in enumerate(msas): for msa_index, msa in enumerate(msas):
...@@ -72,8 +71,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: ...@@ -72,8 +71,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
deletion_matrix.append(msa.deletion_matrix[sequence_index]) deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers( identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index]) msa.descriptions[sequence_index])
uniprot_accession_ids.append(
identifiers.uniprot_accession_id.encode('utf-8'))
species_ids.append(identifiers.species_id.encode('utf-8')) species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0]) num_res = len(msas[0].sequences[0])
...@@ -83,8 +80,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: ...@@ -83,8 +80,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features['msa'] = np.array(int_msa, dtype=np.int32) features['msa'] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array( features['num_alignments'] = np.array(
[num_alignments] * num_res, dtype=np.int32) [num_alignments] * num_res, dtype=np.int32)
features['msa_uniprot_accession_identifiers'] = np.array(
uniprot_accession_ids, dtype=np.object_)
features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_) features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
return features return features
......
...@@ -231,7 +231,6 @@ class DataPipeline: ...@@ -231,7 +231,6 @@ class DataPipeline:
msa = msa.truncate(max_seqs=self._max_uniprot_hits) msa = msa.truncate(max_seqs=self._max_uniprot_hits)
all_seq_features = pipeline.make_msa_features([msa]) all_seq_features = pipeline.make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + ( valid_feats = msa_pairing.MSA_FEATURES + (
'msa_uniprot_accession_identifiers',
'msa_species_identifiers', 'msa_species_identifiers',
) )
feats = {f'{k}_all_seq': v for k, v in all_seq_features.items() feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
...@@ -240,8 +239,7 @@ class DataPipeline: ...@@ -240,8 +239,7 @@ class DataPipeline:
def process(self, def process(self,
input_fasta_path: str, input_fasta_path: str,
msa_output_dir: str, msa_output_dir: str) -> pipeline.FeatureDict:
is_prokaryote: bool = False) -> pipeline.FeatureDict:
"""Runs alignment tools on the input sequences and creates features.""" """Runs alignment tools on the input sequences and creates features."""
with open(input_fasta_path) as f: with open(input_fasta_path) as f:
input_fasta_str = f.read() input_fasta_str = f.read()
...@@ -278,9 +276,7 @@ class DataPipeline: ...@@ -278,9 +276,7 @@ class DataPipeline:
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing.pair_and_merge( np_example = feature_processing.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features)
is_prokaryote=is_prokaryote,
)
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
......
...@@ -45,12 +45,6 @@ flags.DEFINE_list( ...@@ -45,12 +45,6 @@ flags.DEFINE_list(
'multiple sequences, then it will be folded as a multimer. Paths should be ' 'multiple sequences, then it will be folded as a multimer. Paths should be '
'separated by commas. All FASTA paths must have a unique basename as the ' 'separated by commas. All FASTA paths must have a unique basename as the '
'basename is used to name the output directories for each prediction.') 'basename is used to name the output directories for each prediction.')
flags.DEFINE_list(
'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
'single chain system. This list should contain a boolean for each fasta '
'specifying true where the target complex is from a prokaryote, and false '
'where it is not, or where the origin is unknown. These values determine '
'the pairing method for the MSA.')
flags.DEFINE_string( flags.DEFINE_string(
'output_dir', '/tmp/alphafold', 'output_dir', '/tmp/alphafold',
'Path to a directory that will store the results.') 'Path to a directory that will store the results.')
...@@ -224,10 +218,6 @@ def main(argv): ...@@ -224,10 +218,6 @@ def main(argv):
'--logtostderr', '--logtostderr',
]) ])
if FLAGS.is_prokaryote_list:
command_args.append(
f'--is_prokaryote_list={",".join(FLAGS.is_prokaryote_list)}')
client = docker.from_env() client = docker.from_env()
container = client.containers.run( container = client.containers.run(
image=FLAGS.docker_image_name, image=FLAGS.docker_image_name,
......
...@@ -231,12 +231,6 @@ ...@@ -231,12 +231,6 @@
"input_sequences = (sequence_1, sequence_2, sequence_3, sequence_4,\n", "input_sequences = (sequence_1, sequence_2, sequence_3, sequence_4,\n",
" sequence_5, sequence_6, sequence_7, sequence_8)\n", " sequence_5, sequence_6, sequence_7, sequence_8)\n",
"\n", "\n",
"#@markdown If folding a complex target and all the input sequences are\n",
"#@markdown prokaryotic then set `is_prokaryotic` to `True`. Set to `False`\n",
"#@markdown otherwise or if the origin is unknown.\n",
"\n",
"is_prokaryote = False #@param {type:\"boolean\"}\n",
"\n",
"MIN_SINGLE_SEQUENCE_LENGTH = 16\n", "MIN_SINGLE_SEQUENCE_LENGTH = 16\n",
"MAX_SINGLE_SEQUENCE_LENGTH = 2500\n", "MAX_SINGLE_SEQUENCE_LENGTH = 2500\n",
"MAX_MULTIMER_LENGTH = 2500\n", "MAX_MULTIMER_LENGTH = 2500\n",
...@@ -426,7 +420,6 @@ ...@@ -426,7 +420,6 @@
" # Construct the all_seq features only for heteromers, not homomers.\n", " # Construct the all_seq features only for heteromers, not homomers.\n",
" if model_type_to_use == notebook_utils.ModelType.MULTIMER and len(set(sequences)) \u003e 1:\n", " if model_type_to_use == notebook_utils.ModelType.MULTIMER and len(set(sequences)) \u003e 1:\n",
" valid_feats = msa_pairing.MSA_FEATURES + (\n", " valid_feats = msa_pairing.MSA_FEATURES + (\n",
" 'msa_uniprot_accession_identifiers',\n",
" 'msa_species_identifiers',\n", " 'msa_species_identifiers',\n",
" )\n", " )\n",
" all_seq_features = {\n", " all_seq_features = {\n",
...@@ -450,7 +443,7 @@ ...@@ -450,7 +443,7 @@
" all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)\n", " all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)\n",
"\n", "\n",
" np_example = feature_processing.pair_and_merge(\n", " np_example = feature_processing.pair_and_merge(\n",
" all_chain_features=all_chain_features, is_prokaryote=is_prokaryote)\n", " all_chain_features=all_chain_features)\n",
"\n", "\n",
" # Pad MSA to avoid zero-sized extra_msa.\n", " # Pad MSA to avoid zero-sized extra_msa.\n",
" np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)" " np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)"
......
...@@ -49,12 +49,6 @@ flags.DEFINE_list( ...@@ -49,12 +49,6 @@ flags.DEFINE_list(
'multiple sequences, then it will be folded as a multimer. Paths should be ' 'multiple sequences, then it will be folded as a multimer. Paths should be '
'separated by commas. All FASTA paths must have a unique basename as the ' 'separated by commas. All FASTA paths must have a unique basename as the '
'basename is used to name the output directories for each prediction.') 'basename is used to name the output directories for each prediction.')
flags.DEFINE_list(
'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
'single chain system. This list should contain a boolean for each fasta '
'specifying true where the target complex is from a prokaryote, and false '
'where it is not, or where the origin is unknown. These values determine '
'the pairing method for the MSA.')
flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
...@@ -162,8 +156,7 @@ def predict_structure( ...@@ -162,8 +156,7 @@ def predict_structure(
model_runners: Dict[str, model.RunModel], model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation, amber_relaxer: relax.AmberRelaxation,
benchmark: bool, benchmark: bool,
random_seed: int, random_seed: int):
is_prokaryote: Optional[bool] = None):
"""Predicts structure using AlphaFold for the given sequence.""" """Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name) logging.info('Predicting %s', fasta_name)
timings = {} timings = {}
...@@ -176,15 +169,9 @@ def predict_structure( ...@@ -176,15 +169,9 @@ def predict_structure(
# Get features. # Get features.
t_0 = time.time() t_0 = time.time()
if is_prokaryote is None:
feature_dict = data_pipeline.process( feature_dict = data_pipeline.process(
input_fasta_path=fasta_path, input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir) msa_output_dir=msa_output_dir)
else:
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir,
is_prokaryote=is_prokaryote)
timings['features'] = time.time() - t_0 timings['features'] = time.time() - t_0
# Write out features as a pickled dictionary. # Write out features as a pickled dictionary.
...@@ -324,22 +311,6 @@ def main(argv): ...@@ -324,22 +311,6 @@ def main(argv):
if len(fasta_names) != len(set(fasta_names)): if len(fasta_names) != len(set(fasta_names)):
raise ValueError('All FASTA paths must have a unique basename.') raise ValueError('All FASTA paths must have a unique basename.')
# Check that is_prokaryote_list has same number of elements as fasta_paths,
# and convert to bool.
if FLAGS.is_prokaryote_list:
if len(FLAGS.is_prokaryote_list) != len(FLAGS.fasta_paths):
raise ValueError('--is_prokaryote_list must either be omitted or match '
'length of --fasta_paths.')
is_prokaryote_list = []
for s in FLAGS.is_prokaryote_list:
if s in ('true', 'false'):
is_prokaryote_list.append(s == 'true')
else:
raise ValueError('--is_prokaryote_list must contain comma separated '
'true or false values.')
else: # Default is_prokaryote to False.
is_prokaryote_list = [False] * len(fasta_names)
if run_multimer_system: if run_multimer_system:
template_searcher = hmmsearch.Hmmsearch( template_searcher = hmmsearch.Hmmsearch(
binary_path=FLAGS.hmmsearch_binary_path, binary_path=FLAGS.hmmsearch_binary_path,
...@@ -423,7 +394,6 @@ def main(argv): ...@@ -423,7 +394,6 @@ def main(argv):
# Predict structure for each of the sequences. # Predict structure for each of the sequences.
for i, fasta_path in enumerate(FLAGS.fasta_paths): for i, fasta_path in enumerate(FLAGS.fasta_paths):
is_prokaryote = is_prokaryote_list[i] if run_multimer_system else None
fasta_name = fasta_names[i] fasta_name = fasta_names[i]
predict_structure( predict_structure(
fasta_path=fasta_path, fasta_path=fasta_path,
...@@ -433,8 +403,7 @@ def main(argv): ...@@ -433,8 +403,7 @@ def main(argv):
model_runners=model_runners, model_runners=model_runners,
amber_relaxer=amber_relaxer, amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark, benchmark=FLAGS.benchmark,
random_seed=random_seed, random_seed=random_seed)
is_prokaryote=is_prokaryote)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment