Commit 89a6875f authored by Rich Evans's avatar Rich Evans Committed by Copybara-Service
Browse files

Remove prokaryotic pairing from AFOS.

PiperOrigin-RevId: 429521382
Change-Id: Ib9bfce50834bb129af0dbd3508cd0b92856a9bb1
parent b787ca29
......@@ -201,6 +201,11 @@ change the following:
`alphafold/model/config.py`.
* Setting the `data_dir` flag is now needed when using `run_docker.py`.
### API changes between v2.1.0 and v2.2.0
The `--is_prokaryote_list` flag has been removed along with the `is_prokaryote`
argument in `run_alphafold.predict_structure()`; eukaryotes and prokaryotes are
now paired in the same way.
## Running AlphaFold
......@@ -303,16 +308,12 @@ All steps are the same as when running the monomer system, but you will have to
* provide an input fasta with multiple sequences,
* set `--model_preset=multimer`,
* optionally set the `--is_prokaryote_list` flag with booleans that determine
whether all input sequences in the given fasta file are prokaryotic. If that
is not the case or the origin is unknown, set to `false` for that fasta.
An example that folds a protein complex `multimer.fasta` that is prokaryotic:
An example that folds a protein complex `multimer.fasta`:
```bash
python3 docker/run_docker.py \
--fasta_paths=multimer.fasta \
--is_prokaryote_list=true \
--max_template_date=2020-05-14 \
--model_preset=multimer \
--data_dir=$DOWNLOAD_DIR
......@@ -348,7 +349,7 @@ python3 docker/run_docker.py \
#### Folding a homomer
Say we have a homomer from a prokaryote with 3 copies of the same sequence
Say we have a homomer with 3 copies of the same sequence
`<SEQUENCE>`. The input fasta should be:
```fasta
......@@ -365,7 +366,6 @@ Then run the following command:
```bash
python3 docker/run_docker.py \
--fasta_paths=homomer.fasta \
--is_prokaryote_list=true \
--max_template_date=2021-11-01 \
--model_preset=multimer \
--data_dir=$DOWNLOAD_DIR
......@@ -373,7 +373,7 @@ python3 docker/run_docker.py \
#### Folding a heteromer
Say we have a heteromer A2B3 of unknown origin, i.e. with 2 copies of
Say we have an A2B3 heteromer, i.e. with 2 copies of
`<SEQUENCE A>` and 3 copies of `<SEQUENCE B>`. The input fasta should be:
```fasta
......@@ -394,7 +394,6 @@ Then run the following command:
```bash
python3 docker/run_docker.py \
--fasta_paths=heteromer.fasta \
--is_prokaryote_list=false \
--max_template_date=2021-11-01 \
--model_preset=multimer \
--data_dir=$DOWNLOAD_DIR
......@@ -416,15 +415,13 @@ python3 docker/run_docker.py \
#### Folding multiple multimers one after another
Say we have two multimers, `multimer1.fasta` and `multimer2.fasta`. Both are
from a prokaryotic organism.
Say we have two multimers, `multimer1.fasta` and `multimer2.fasta`.
We can fold both sequentially by using the following command:
```bash
python3 docker/run_docker.py \
--fasta_paths=multimer1.fasta,multimer2.fasta \
--is_prokaryote_list=true,true \
--max_template_date=2021-11-01 \
--model_preset=multimer \
--data_dir=$DOWNLOAD_DIR
......
......@@ -46,14 +46,12 @@ def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
def pair_and_merge(
all_chain_features: MutableMapping[str, pipeline.FeatureDict],
is_prokaryote: bool) -> pipeline.FeatureDict:
all_chain_features: MutableMapping[str, pipeline.FeatureDict]
) -> pipeline.FeatureDict:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
is_prokaryote: Whether the target complex is from a prokaryotic or
eukaryotic organism.
Returns:
A dictionary of features.
......@@ -67,7 +65,7 @@ def pair_and_merge(
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list, prokaryotic=is_prokaryote)
chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
......
......@@ -48,12 +48,11 @@ _UNIPROT_PATTERN = re.compile(
@dataclasses.dataclass(frozen=True)
class Identifiers:
  """Identifiers parsed from a UniProt MSA sequence identifier.

  Both fields default to the empty string, which is used when no identifier
  could be parsed from the sequence description.
  """
  # UniProt accession id; empty string when no identifier was found.
  uniprot_accession_id: str = ''
  # UniProt species id; empty string when no identifier was found.
  species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier.
"""Gets species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
......@@ -63,13 +62,12 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with a uniprot_accession_id and species_id. These
An `Identifiers` instance with species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
uniprot_accession_id=matches.group('AccessionIdentifier'),
species_id=matches.group('SpeciesIdentifier'))
return Identifiers()
......
......@@ -25,12 +25,6 @@ import numpy as np
import pandas as pd
import scipy.linalg
ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
ALPHANUM_ACCESSION_ID_MAP = {
chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
} # A-Z,0-9
NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)} # 0-9
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
......@@ -58,15 +52,11 @@ CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features(
chains: Iterable[pipeline.FeatureDict],
prokaryotic: bool,
) -> List[pipeline.FeatureDict]:
chains: Iterable[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
prokaryotic: Whether the target complex is from a prokaryotic organism.
Used to determine the distance metric for pairing.
Returns:
A list of feature dictionaries with sequence features including only
......@@ -79,8 +69,7 @@ def create_paired_features(
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(
chains, prokaryotic)
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices)
......@@ -115,8 +104,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
elif feature_name in ('msa_uniprot_accession_identifiers_all_seq',
'msa_species_identifiers_all_seq'):
elif feature_name == 'msa_species_identifiers_all_seq':
padding = [b'']
else:
return feature
......@@ -134,11 +122,9 @@ def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_uniprot_accession_identifiers':
chain_features['msa_uniprot_accession_identifiers_all_seq'],
'msa_row':
np.arange(len(
chain_features['msa_uniprot_accession_identifiers_all_seq'])),
chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
})
......@@ -153,139 +139,6 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
return species_lookup
@functools.lru_cache(maxsize=65536)
def encode_accession(accession_id: str) -> int:
  """Map accession codes to the serial order in which they were assigned.

  The encoding treats the accession id as a mixed-radix number, so the
  absolute difference between two encodings approximates how far apart the
  two ids were assigned.

  Args:
    accession_id: A UniProt accession id string (6 or 10 characters). See
      https://www.uniprot.org/help/accession_numbers for the format.

  Returns:
    An integer encoding of the accession id.

  Raises:
    ValueError: If the accession id is empty or does not match a recognized
      UniProt accession id format.
  """
  alpha = ALPHA_ACCESSION_ID_MAP  # A-Z
  alphanum = ALPHANUM_ACCESSION_ID_MAP  # A-Z,0-9
  num = NUM_ACCESSION_ID_MAP  # 0-9

  coding = 0

  # This is based on the uniprot accession id format
  # https://www.uniprot.org/help/accession_numbers
  if not accession_id:
    raise ValueError('Empty UniProt accession id.')
  if accession_id[0] in {'O', 'P', 'Q'}:
    bases = (alpha, num, alphanum, alphanum, alphanum, num)
  elif len(accession_id) == 6:
    bases = (alpha, num, alpha, alphanum, alphanum, num)
  elif len(accession_id) == 10:
    bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum,
             alphanum, num)
  else:
    # Previously an unrecognized format fell through and crashed with an
    # opaque UnboundLocalError on `bases`; fail fast with a clear error.
    raise ValueError(
        f'Unrecognized UniProt accession id format: {accession_id!r}')

  # Accumulate the mixed-radix value from the least significant (rightmost)
  # character to the most significant.
  product = 1
  for place, base in zip(reversed(accession_id), reversed(bases)):
    coding += base[place] * product
    product *= len(base)

  return coding
def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
  """Returns the absolute distance between two encoded accession ids."""
  code_a = encode_accession(id_a.decode())
  code_b = encode_accession(id_b.decode())
  return abs(code_a - code_b)
def _find_all_accession_matches(accession_id_lists: List[List[bytes]],
                                diff_cutoff: int = 20
                                ) -> List[List[Any]]:
  """Finds accession id matches across the chains based on their difference.

  Performs a depth-first search over one accession id per chain, keeping a
  candidate tuple only if every pair of ids in it differs by less than
  `diff_cutoff` (per `_calc_id_diff`) and no id in it was already used in a
  previously accepted tuple.

  Args:
    accession_id_lists: One list of accession ids per chain.
    diff_cutoff: Maximum allowed pairwise encoded-id difference within a match.

  Returns:
    A list of accepted tuples, each containing one accession id per chain.
  """
  all_accession_tuples = []
  # Shared mutable DFS state: the tuple being built and the ids already
  # committed to an accepted answer.
  current_tuple = []
  tokens_used_in_answer = set()

  def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
    # True if `inp` is within the cutoff of every id already in the tuple.
    return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple))

  def _all_tokens_not_used_before() -> bool:
    # True if no id in the current tuple was consumed by an earlier answer.
    return all((s not in tokens_used_in_answer for s in current_tuple))

  def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
    # `level` is the index of the chain whose id was chosen last; the search
    # starts at level -1 with a placeholder id.
    if level == len(accession_id_lists) - 1:
      # One id chosen for every chain: accept if none was used before.
      if _all_tokens_not_used_before():
        all_accession_tuples.append(list(current_tuple))
        for s in current_tuple:
          tokens_used_in_answer.add(s)
      return

    if level == -1:
      new_list = accession_id_lists[level+1]
    else:
      # Visit the next chain's candidates closest to the current id first.
      new_list = [(_calc_id_diff(accession_id, s), s) for
                  s in accession_id_lists[level+1]]
      new_list = sorted(new_list)
      new_list = [s for d, s in new_list]

    for s in new_list:
      if (_matches_all_in_current_tuple(s, diff_cutoff) and
          s not in tokens_used_in_answer):
        # Tentatively extend the tuple, recurse, then backtrack.
        current_tuple.append(s)
        dfs(level + 1, s)
        current_tuple.pop()

  dfs(-1, '')
  return all_accession_tuples
def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series:
matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id]
return matched_df.iloc[0]
def _match_rows_by_genetic_distance(
    this_species_msa_dfs: List[pd.DataFrame],
    cutoff: int = 20) -> List[List[int]]:
  """Finds MSA sequence pairings across chains within a genetic distance cutoff.

  The genetic distance between two sequences is approximated by taking the
  difference in their UniProt accession ids.

  Args:
    this_species_msa_dfs: a list of dataframes containing MSA features for
      sequences for a specific species. If species is missing for a chain, the
      dataframe is set to None.
    cutoff: the genetic distance cutoff.

  Returns:
    A list of lists, each containing M indices corresponding to paired MSA rows,
    where M is the number of chains.
  """
  num_examples = len(this_species_msa_dfs)  # N

  # One accession id list per chain that actually has data for this species.
  accession_id_lists = []  # M
  match_index_to_chain_index = {}
  for chain_index, species_df in enumerate(this_species_msa_dfs):
    if species_df is not None:
      accession_id_lists.append(
          list(species_df.msa_uniprot_accession_identifiers.values))
      # Keep track of which of the this_species_msa_dfs are not None.
      match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index

  all_accession_id_matches = _find_all_accession_matches(
      accession_id_lists, cutoff)  # [k, M]

  all_paired_msa_rows = []  # [k, N]
  for accession_id_match in all_accession_id_matches:
    paired_msa_rows = []
    for match_index, accession_id in enumerate(accession_id_match):
      # Map back to chain index.
      chain_index = match_index_to_chain_index[match_index]
      seq_series = _accession_row(
          this_species_msa_dfs[chain_index], accession_id)
      # Drop sequences too similar to the target or with too many gaps.
      if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or
          seq_series.gap > SEQUENCE_GAP_CUTOFF):
        continue
      else:
        paired_msa_rows.append(seq_series.msa_row)
    # If a sequence is skipped based on sequence similarity to the respective
    # target sequence or a gap cutoff, the lengths of accession_id_match and
    # paired_msa_rows will be different. Skip this match.
    if len(paired_msa_rows) == len(accession_id_match):
      # Chains without data for this species get sentinel row index -1.
      paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
      matched_chain_indices = list(match_index_to_chain_index.values())
      paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
      all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))
  return all_paired_msa_rows
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
......@@ -322,8 +175,8 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
return all_paired_msa_rows
def pair_sequences(examples: List[pipeline.FeatureDict],
prokaryotic: bool) -> Dict[int, np.ndarray]:
def pair_sequences(examples: List[pipeline.FeatureDict]
) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
......@@ -365,23 +218,7 @@ def pair_sequences(examples: List[pipeline.FeatureDict],
isinstance(species_df, pd.DataFrame)]) > 600):
continue
# In prokaryotes (and some eukaryotes), interacting genes are often
# co-located on the chromosome into operons. Because of that we can assume
# that if two proteins' intergenic distance is less than a threshold, the
# two proteins will form an interacting pair.
# In most eukaryotes, a single protein's MSA can contain many paralogs.
# Two genes may interact even if they are not close by genomic distance.
# In case of eukaryotes, some methods pair MSA sequences using sequence
# similarity method.
# See Jinbo Xu's work:
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
if prokaryotic:
paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs)
if not paired_msa_rows:
continue
else:
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
......
......@@ -57,7 +57,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
int_msa = []
deletion_matrix = []
uniprot_accession_ids = []
species_ids = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
......@@ -72,8 +71,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index])
uniprot_accession_ids.append(
identifiers.uniprot_accession_id.encode('utf-8'))
species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0])
......@@ -83,8 +80,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features['msa'] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array(
[num_alignments] * num_res, dtype=np.int32)
features['msa_uniprot_accession_identifiers'] = np.array(
uniprot_accession_ids, dtype=np.object_)
features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
return features
......
......@@ -231,7 +231,6 @@ class DataPipeline:
msa = msa.truncate(max_seqs=self._max_uniprot_hits)
all_seq_features = pipeline.make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_uniprot_accession_identifiers',
'msa_species_identifiers',
)
feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
......@@ -240,8 +239,7 @@ class DataPipeline:
def process(self,
input_fasta_path: str,
msa_output_dir: str,
is_prokaryote: bool = False) -> pipeline.FeatureDict:
msa_output_dir: str) -> pipeline.FeatureDict:
"""Runs alignment tools on the input sequences and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
......@@ -278,9 +276,7 @@ class DataPipeline:
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing.pair_and_merge(
all_chain_features=all_chain_features,
is_prokaryote=is_prokaryote,
)
all_chain_features=all_chain_features)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
......
......@@ -45,12 +45,6 @@ flags.DEFINE_list(
'multiple sequences, then it will be folded as a multimer. Paths should be '
'separated by commas. All FASTA paths must have a unique basename as the '
'basename is used to name the output directories for each prediction.')
flags.DEFINE_list(
'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
'single chain system. This list should contain a boolean for each fasta '
'specifying true where the target complex is from a prokaryote, and false '
'where it is not, or where the origin is unknown. These values determine '
'the pairing method for the MSA.')
flags.DEFINE_string(
'output_dir', '/tmp/alphafold',
'Path to a directory that will store the results.')
......@@ -224,10 +218,6 @@ def main(argv):
'--logtostderr',
])
if FLAGS.is_prokaryote_list:
command_args.append(
f'--is_prokaryote_list={",".join(FLAGS.is_prokaryote_list)}')
client = docker.from_env()
container = client.containers.run(
image=FLAGS.docker_image_name,
......
......@@ -231,12 +231,6 @@
"input_sequences = (sequence_1, sequence_2, sequence_3, sequence_4,\n",
" sequence_5, sequence_6, sequence_7, sequence_8)\n",
"\n",
"#@markdown If folding a complex target and all the input sequences are\n",
"#@markdown prokaryotic then set `is_prokaryotic` to `True`. Set to `False`\n",
"#@markdown otherwise or if the origin is unknown.\n",
"\n",
"is_prokaryote = False #@param {type:\"boolean\"}\n",
"\n",
"MIN_SINGLE_SEQUENCE_LENGTH = 16\n",
"MAX_SINGLE_SEQUENCE_LENGTH = 2500\n",
"MAX_MULTIMER_LENGTH = 2500\n",
......@@ -426,7 +420,6 @@
" # Construct the all_seq features only for heteromers, not homomers.\n",
" if model_type_to_use == notebook_utils.ModelType.MULTIMER and len(set(sequences)) \u003e 1:\n",
" valid_feats = msa_pairing.MSA_FEATURES + (\n",
" 'msa_uniprot_accession_identifiers',\n",
" 'msa_species_identifiers',\n",
" )\n",
" all_seq_features = {\n",
......@@ -450,7 +443,7 @@
" all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)\n",
"\n",
" np_example = feature_processing.pair_and_merge(\n",
" all_chain_features=all_chain_features, is_prokaryote=is_prokaryote)\n",
" all_chain_features=all_chain_features)\n",
"\n",
" # Pad MSA to avoid zero-sized extra_msa.\n",
" np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)"
......
......@@ -49,12 +49,6 @@ flags.DEFINE_list(
'multiple sequences, then it will be folded as a multimer. Paths should be '
'separated by commas. All FASTA paths must have a unique basename as the '
'basename is used to name the output directories for each prediction.')
flags.DEFINE_list(
'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
'single chain system. This list should contain a boolean for each fasta '
'specifying true where the target complex is from a prokaryote, and false '
'where it is not, or where the origin is unknown. These values determine '
'the pairing method for the MSA.')
flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
......@@ -162,8 +156,7 @@ def predict_structure(
model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation,
benchmark: bool,
random_seed: int,
is_prokaryote: Optional[bool] = None):
random_seed: int):
"""Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name)
timings = {}
......@@ -176,15 +169,9 @@ def predict_structure(
# Get features.
t_0 = time.time()
if is_prokaryote is None:
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir)
else:
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir,
is_prokaryote=is_prokaryote)
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir)
timings['features'] = time.time() - t_0
# Write out features as a pickled dictionary.
......@@ -324,22 +311,6 @@ def main(argv):
if len(fasta_names) != len(set(fasta_names)):
raise ValueError('All FASTA paths must have a unique basename.')
# Check that is_prokaryote_list has same number of elements as fasta_paths,
# and convert to bool.
if FLAGS.is_prokaryote_list:
if len(FLAGS.is_prokaryote_list) != len(FLAGS.fasta_paths):
raise ValueError('--is_prokaryote_list must either be omitted or match '
'length of --fasta_paths.')
is_prokaryote_list = []
for s in FLAGS.is_prokaryote_list:
if s in ('true', 'false'):
is_prokaryote_list.append(s == 'true')
else:
raise ValueError('--is_prokaryote_list must contain comma separated '
'true or false values.')
else: # Default is_prokaryote to False.
is_prokaryote_list = [False] * len(fasta_names)
if run_multimer_system:
template_searcher = hmmsearch.Hmmsearch(
binary_path=FLAGS.hmmsearch_binary_path,
......@@ -423,7 +394,6 @@ def main(argv):
# Predict structure for each of the sequences.
for i, fasta_path in enumerate(FLAGS.fasta_paths):
is_prokaryote = is_prokaryote_list[i] if run_multimer_system else None
fasta_name = fasta_names[i]
predict_structure(
fasta_path=fasta_path,
......@@ -433,8 +403,7 @@ def main(argv):
model_runners=model_runners,
amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark,
random_seed=random_seed,
is_prokaryote=is_prokaryote)
random_seed=random_seed)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment