# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for getting templates and calculating template features."""
import abc
import dataclasses
import datetime
import functools
import glob
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple

from absl import logging
from alphafold.common import residue_constants
from alphafold.data import mmcif_parsing
from alphafold.data import parsers
from alphafold.data.tools import kalign
import numpy as np

# Internal import (7716).


class Error(Exception):
  """Base class for exceptions."""


class NoChainsError(Error):
  """An error indicating that template mmCIF didn't have any chains."""


class SequenceNotInTemplateError(Error):
  """An error indicating that template mmCIF didn't contain the sequence."""


class NoAtomDataInTemplateError(Error):
  """An error indicating that template mmCIF didn't contain atom positions."""


class TemplateAtomMaskAllZerosError(Error):
  """An error indicating that template mmCIF had all atom positions masked."""


class QueryToTemplateAlignError(Error):
  """An error indicating that the query can't be aligned to the template."""


class CaDistanceError(Error):
  """An error indicating that a CA atom distance exceeds a threshold."""


class MultipleChainsError(Error):
  """An error indicating that multiple chains were found for a given ID."""


# Prefilter exceptions.
class PrefilterError(Exception):
  """A base class for template prefilter exceptions."""


class DateError(PrefilterError):
  """An error indicating that the hit date was after the max allowed date."""


class AlignRatioError(PrefilterError):
  """An error indicating that the hit align ratio to the query was too small."""


class DuplicateError(PrefilterError):
  """An error indicating that the hit was an exact subsequence of the query."""


class LengthError(PrefilterError):
  """An error indicating that the hit was too short."""


TEMPLATE_FEATURES = {
    'template_aatype': np.float32,
    'template_all_atom_masks': np.float32,
    'template_all_atom_positions': np.float32,
    'template_domain_names': object,
    'template_sequence': object,
    'template_sum_probs': np.float32,
}
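# For reference, the featurizers below stack one entry per accepted template
# hit, so the resulting arrays roughly have these shapes (num_res is the query
# length, using residue_constants names for the per-residue dimensions):
#   template_aatype:             (num_templates, num_res, len(restypes_with_x_and_gap))
#   template_all_atom_masks:     (num_templates, num_res, atom_type_num)
#   template_all_atom_positions: (num_templates, num_res, atom_type_num, 3)
#   template_domain_names:       (num_templates,)
#   template_sequence:           (num_templates,)
#   template_sum_probs:          (num_templates, 1)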


def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
  """Returns PDB id and chain id for an HHSearch Hit."""
  # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
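  # E.g. a hit named '6xyz_A Some description' (hypothetical) would yield the
  # tuple ('6xyz', 'A').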
  id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
  if not id_match:
    raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}')
  pdb_id, chain_id = id_match.group(0).split('_')
  return pdb_id.lower(), chain_id


def _is_after_cutoff(
    pdb_id: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: Optional[datetime.datetime]) -> bool:
  """Checks if the template date is after the release date cutoff.

  Args:
    pdb_id: 4 letter pdb code.
    release_dates: Dictionary mapping PDB ids to their structure release dates.
    release_date_cutoff: Max release date that is valid for this query.

  Returns:
    True if the template release date is after the cutoff, False otherwise.
  """
  if release_date_cutoff is None:
    raise ValueError('The release_date_cutoff must not be None.')
  if pdb_id in release_dates:
    return release_dates[pdb_id] > release_date_cutoff
  else:
    # Since this is just a quick prefilter to reduce the number of mmCIF files
    # we need to parse, we don't have to worry about returning True here.
    return False


def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, Optional[str]]:
  """Parses the data file from PDB that lists which pdb_ids are obsolete."""
  with open(obsolete_file_path) as f:
    result = {}
    for line in f:
      line = line.strip()
      # Format:    Date      From     To
      # 'OBSLTE    06-NOV-19 6G9Y'                - Removed, rare
      # 'OBSLTE    31-JUL-94 116L     216L'       - Replaced, common
      # 'OBSLTE    26-SEP-06 2H33     2JM5 2OWI'  - Replaced by multiple, rare
      if line.startswith('OBSLTE'):
        if len(line) > 30:
          # Replaced by at least one structure.
          from_id = line[20:24].lower()
          to_id = line[29:33].lower()
          result[from_id] = to_id
        elif len(line) == 24:
          # Removed.
          from_id = line[20:24].lower()
          result[from_id] = None
    return result


def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
  """Parses release dates file, returns a mapping from PDBs to release dates."""
  if path.endswith('txt'):
    release_dates = {}
    with open(path, 'r') as f:
      for line in f:
        pdb_id, date = line.split(':')
        date = date.strip()
        # Python 3.6 doesn't have datetime.date.fromisoformat() which is about
        # 90x faster than strptime. However, splitting the string manually is
        # about 10x faster than strptime.
        release_dates[pdb_id.strip()] = datetime.datetime(
            year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10]))
    return release_dates
  else:
    raise ValueError('Invalid format of the release date file %s.' % path)


def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
    min_align_ratio: float = 0.1) -> bool:
  """Determines if template is valid (without parsing the template mmcif file).

  Args:
    hit: HhrHit for the template.
    hit_pdb_code: The 4 letter pdb code of the template hit. This might be
      different from the value in the actual hit since the original pdb might
      have become obsolete.
    query_sequence: Amino acid sequence of the query.
    release_dates: Dictionary mapping pdb codes to their structure release
      dates.
    release_date_cutoff: Max release date that is valid for this query.
    max_subsequence_ratio: Exclude any exact matches with this much overlap.
    min_align_ratio: Minimum overlap between the template and query.

  Returns:
    True if the hit passed the prefilter. Raises an exception otherwise.

  Raises:
    DateError: If the hit date was after the max allowed date.
    AlignRatioError: If the hit align ratio to the query was too small.
    DuplicateError: If the hit was an exact subsequence of the query.
    LengthError: If the hit was too short.
  """
  aligned_cols = hit.aligned_cols
  align_ratio = aligned_cols / len(query_sequence)

  template_sequence = hit.hit_sequence.replace('-', '')
  length_ratio = float(len(template_sequence)) / len(query_sequence)

  # Check whether the template is a large subsequence or duplicate of original
  # query. This can happen due to duplicate entries in the PDB database.
  duplicate = (template_sequence in query_sequence and
               length_ratio > max_subsequence_ratio)

  if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
    raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date '
                    f'({release_date_cutoff}).')

  if align_ratio <= min_align_ratio:
    raise AlignRatioError('Proportion of residues aligned to query too small. '
                          f'Align ratio: {align_ratio}.')

  if duplicate:
    raise DuplicateError('Template is an exact subsequence of query with large '
                         f'coverage. Length ratio: {length_ratio}.')

  if len(template_sequence) < 10:
    raise LengthError(f'Template too short. Length: {len(template_sequence)}.')

  return True


def _find_template_in_pdb(
    template_chain_id: str,
    template_sequence: str,
    mmcif_object: mmcif_parsing.MmcifObject) -> Tuple[str, str, int]:
  """Tries to find the template chain in the given pdb file.

  This method tries the following three things in order:
    1. Checks whether there is an exact match in both the chain ID and the
       sequence. If yes, the chain sequence is returned. Otherwise:
    2. Checks whether there is an exact match in the sequence only.
       If yes, the chain sequence is returned. Otherwise:
    3. Checks whether there is a fuzzy match (X = wildcard) in the sequence.
       If yes, the chain sequence is returned.
  If none of these succeed, a SequenceNotInTemplateError is raised.

  Args:
    template_chain_id: The template chain ID.
    template_sequence: The template chain sequence.
    mmcif_object: The PDB object to search for the template in.

  Returns:
    A tuple with:
    * The chain sequence that was found to match the template in the PDB object.
    * The ID of the chain that is being returned.
    * The offset where the template sequence starts in the chain sequence.

  Raises:
    SequenceNotInTemplateError: If no match is found after the steps described
      above.
  """
  # Try an exact match in both the chain ID and the (sub)sequence.
  pdb_id = mmcif_object.file_id
  chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
  if chain_sequence and (template_sequence in chain_sequence):
    logging.info(
        'Found an exact template match %s_%s.', pdb_id, template_chain_id)
    mapping_offset = chain_sequence.find(template_sequence)
    return chain_sequence, template_chain_id, mapping_offset

  # Try an exact match in the (sub)sequence only.
  for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
    if chain_sequence and (template_sequence in chain_sequence):
      logging.info('Found a sequence-only match %s_%s.', pdb_id, chain_id)
      mapping_offset = chain_sequence.find(template_sequence)
      return chain_sequence, chain_id, mapping_offset

  # Return a chain sequence that fuzzy matches (X = wildcard) the template.
  # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
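  # E.g. a template sequence 'AXC' yields the pattern '(?:A|X).(?:C|X)'.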
  regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence]
  regex = re.compile(''.join(regex))
  for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
    match = re.search(regex, chain_sequence)
    if match:
      logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, chain_id)
      mapping_offset = match.start()
      return chain_sequence, chain_id, mapping_offset

  # No hits, raise an error.
  raise SequenceNotInTemplateError(
      'Could not find the template sequence in %s_%s. Template sequence: %s, '
      'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence,
                               mmcif_object.chain_to_seqres))


def _realign_pdb_template_to_query(
    old_template_sequence: str,
    template_chain_id: str,
    mmcif_object: mmcif_parsing.MmcifObject,
    old_mapping: Mapping[int, int],
    kalign_binary_path: str) -> Tuple[str, Mapping[int, int]]:
  """Aligns template from the mmcif_object to the query.

  In case PDB70 contains a different version of the template sequence, we need
  to perform a realignment to the actual sequence that is in the mmCIF file.
  This method performs such a realignment, but returns the new sequence and
  mapping only if the sequence in the mmCIF file is at least 90% identical to
  the old sequence.

  Note that the old_template_sequence comes from the hit and contains only the
  part of the chain that matches the query, while the new_template_sequence is
  the full chain.

  Args:
    old_template_sequence: The template sequence that was returned by the PDB
      template search (typically done using HHSearch).
    template_chain_id: The template chain ID that was returned by the PDB
      template search (typically done using HHSearch). This is used to find the
      right chain in the mmcif_object chain_to_seqres mapping.
    mmcif_object: A mmcif_object which holds the actual template data.
    old_mapping: A mapping from the query sequence to the template sequence.
      This mapping will be used to compute the new mapping from the query
      sequence to the actual mmcif_object template sequence by aligning the
      old_template_sequence and the actual template sequence.
    kalign_binary_path: The path to a kalign executable.

  Returns:
    A tuple (new_template_sequence, new_query_to_template_mapping) where:
    * new_template_sequence is the actual template sequence that was found in
      the mmcif_object.
    * new_query_to_template_mapping is the new mapping from the query to the
      actual template found in the mmcif_object.

  Raises:
    QueryToTemplateAlignError:
    * If there was an error thrown by the alignment tool.
    * Or if the actual template sequence differs by more than 10% from the
      old_template_sequence.
  """
  aligner = kalign.Kalign(binary_path=kalign_binary_path)
  new_template_sequence = mmcif_object.chain_to_seqres.get(
      template_chain_id, '')

  # Sometimes the template chain id is unknown. But if there is only a single
  # sequence within the mmcif_object, it is safe to assume it is that one.
  if not new_template_sequence:
    if len(mmcif_object.chain_to_seqres) == 1:
      logging.info('Could not find %s in %s, but there is only 1 sequence, so '
                   'using that one.',
                   template_chain_id,
                   mmcif_object.file_id)
      new_template_sequence = list(mmcif_object.chain_to_seqres.values())[0]
    else:
      raise QueryToTemplateAlignError(
          f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. '
          'If there are no mmCIF parsing errors, it is possible it was not a '
          'protein chain.')

  try:
    parsed_a3m = parsers.parse_a3m(
        aligner.align([old_template_sequence, new_template_sequence]))
    old_aligned_template, new_aligned_template = parsed_a3m.sequences
  except Exception as e:
    raise QueryToTemplateAlignError(
        'Could not align old template %s to template %s (%s_%s). Error: %s' %
        (old_template_sequence, new_template_sequence, mmcif_object.file_id,
         template_chain_id, str(e)))

  logging.info('Old aligned template: %s\nNew aligned template: %s',
               old_aligned_template, new_aligned_template)

  old_to_new_template_mapping = {}
  old_template_index = -1
  new_template_index = -1
  num_same = 0
  for old_template_aa, new_template_aa in zip(
      old_aligned_template, new_aligned_template):
    if old_template_aa != '-':
      old_template_index += 1
    if new_template_aa != '-':
      new_template_index += 1
    if old_template_aa != '-' and new_template_aa != '-':
      old_to_new_template_mapping[old_template_index] = new_template_index
      if old_template_aa == new_template_aa:
        num_same += 1

  # Require at least 90% sequence identity w.r.t. the shorter of the sequences.
  if float(num_same) / min(
      len(old_template_sequence), len(new_template_sequence)) < 0.9:
    raise QueryToTemplateAlignError(
        'Insufficient similarity of the sequence in the database: %s to the '
        'actual sequence in the mmCIF file %s_%s: %s. We require at least '
        '90%% similarity w.r.t. the shorter of the sequences. This is not a '
        'problem unless you think this is a template that should be included.' %
        (old_template_sequence, mmcif_object.file_id, template_chain_id,
         new_template_sequence))

  new_query_to_template_mapping = {}
  for query_index, old_template_index in old_mapping.items():
    new_query_to_template_mapping[query_index] = (
        old_to_new_template_mapping.get(old_template_index, -1))

  new_template_sequence = new_template_sequence.replace('-', '')

  return new_template_sequence, new_query_to_template_mapping


def _check_residue_distances(all_positions: np.ndarray,
                             all_positions_mask: np.ndarray,
                             max_ca_ca_distance: float):
  """Checks if the distance between unmasked neighbor residues is ok."""
  ca_position = residue_constants.atom_order['CA']
  prev_is_unmasked = False
  prev_calpha = None
  for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
    this_is_unmasked = bool(mask[ca_position])
    if this_is_unmasked:
      this_calpha = coords[ca_position]
      if prev_is_unmasked:
        distance = np.linalg.norm(this_calpha - prev_calpha)
        if distance > max_ca_ca_distance:
          raise CaDistanceError(
              'The distance between residues %d and %d is %f > limit %f.' % (
                  i, i + 1, distance, max_ca_ca_distance))
      prev_calpha = this_calpha
    prev_is_unmasked = this_is_unmasked


def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
  """Gets atom positions and mask from a list of Biopython Residues."""
  num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])

  relevant_chains = [c for c in mmcif_object.structure.get_chains()
                     if c.id == auth_chain_id]
  if len(relevant_chains) != 1:
    raise MultipleChainsError(
        f'Expected exactly one chain in structure with id {auth_chain_id}.')
  chain = relevant_chains[0]

  all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3])
  all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num],
                                dtype=np.int64)
  for res_index in range(num_res):
    pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
    mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
    res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
    if not res_at_position.is_missing:
      res = chain[(res_at_position.hetflag,
                   res_at_position.position.residue_number,
                   res_at_position.position.insertion_code)]
      for atom in res.get_atoms():
        atom_name = atom.get_name()
        x, y, z = atom.get_coord()
        if atom_name in residue_constants.atom_order.keys():
          pos[residue_constants.atom_order[atom_name]] = [x, y, z]
          mask[residue_constants.atom_order[atom_name]] = 1.0
        elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
          # Put the coordinates of the selenium atom in the sulphur column.
          pos[residue_constants.atom_order['SD']] = [x, y, z]
          mask[residue_constants.atom_order['SD']] = 1.0

      # Fix naming errors in arginine residues where NH2 is incorrectly
      # assigned to be closer to CD than NH1.
      cd = residue_constants.atom_order['CD']
      nh1 = residue_constants.atom_order['NH1']
      nh2 = residue_constants.atom_order['NH2']
      if (res.get_resname() == 'ARG' and
          all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
          (np.linalg.norm(pos[nh1] - pos[cd]) >
           np.linalg.norm(pos[nh2] - pos[cd]))):
        pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
        mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()

    all_positions[res_index] = pos
    all_positions_mask[res_index] = mask
  _check_residue_distances(
      all_positions, all_positions_mask, max_ca_ca_distance)
  return all_positions, all_positions_mask


def _extract_template_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    pdb_id: str,
    mapping: Mapping[int, int],
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]:
  """Parses atom positions in the target structure and aligns with the query.

  Atoms for each residue in the template structure are indexed to coincide
  with their corresponding residue in the query sequence, according to the
  alignment mapping provided.

  Args:
    mmcif_object: mmcif_parsing.MmcifObject representing the template.
    pdb_id: PDB code for the template.
    mapping: Dictionary mapping indices in the query sequence to indices in
      the template sequence.
    template_sequence: String describing the amino acid sequence for the
      template protein.
    query_sequence: String describing the amino acid sequence for the query
      protein.
    template_chain_id: String ID describing which chain in the structure proto
      should be used.
    kalign_binary_path: The path to a kalign executable used for template
        realignment.

  Returns:
    A tuple with:
    * A dictionary containing the extra features derived from the template
      protein structure.
    * A warning message if the hit was realigned to the actual mmCIF sequence.
      Otherwise None.

  Raises:
    NoChainsError: If the mmcif object doesn't contain any chains.
    SequenceNotInTemplateError: If the given chain id / sequence can't
      be found in the mmcif object.
    QueryToTemplateAlignError: If the actual template in the mmCIF file
      can't be aligned to the query.
    NoAtomDataInTemplateError: If the mmcif object doesn't contain
      atom positions.
    TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
      unmasked residues.
  """
  if mmcif_object is None or not mmcif_object.chain_to_seqres:
    raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id))

  warning = None
  try:
    seqres, chain_id, mapping_offset = _find_template_in_pdb(
        template_chain_id=template_chain_id,
        template_sequence=template_sequence,
        mmcif_object=mmcif_object)
  except SequenceNotInTemplateError:
    # If PDB70 contains a different version of the template, we use the sequence
    # from the mmcif_object.
    chain_id = template_chain_id
    warning = (
        f'The exact sequence {template_sequence} was not found in '
        f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.')
    logging.warning(warning)
    # This throws an exception if it fails to realign the hit.
    seqres, mapping = _realign_pdb_template_to_query(
        old_template_sequence=template_sequence,
        template_chain_id=template_chain_id,
        mmcif_object=mmcif_object,
        old_mapping=mapping,
        kalign_binary_path=kalign_binary_path)
    logging.info('Sequence in %s_%s: %s successfully realigned to %s',
                 pdb_id, chain_id, template_sequence, seqres)
    # The template sequence changed.
    template_sequence = seqres
    # No mapping offset, the query is aligned to the actual sequence.
    mapping_offset = 0

  try:
    # Essentially set to infinity - we don't want to reject templates unless
    # they're really really bad.
    all_atom_positions, all_atom_mask = _get_atom_positions(
        mmcif_object, chain_id, max_ca_ca_distance=150.0)
  except (CaDistanceError, KeyError) as ex:
    raise NoAtomDataInTemplateError(
        'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex))
        ) from ex

  all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0])
  all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])

  output_templates_sequence = []
  templates_all_atom_positions = []
  templates_all_atom_masks = []

  for _ in query_sequence:
    # Residues in the query_sequence that are not in the template_sequence:
    templates_all_atom_positions.append(
        np.zeros((residue_constants.atom_type_num, 3)))
    templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num))
    output_templates_sequence.append('-')

  for k, v in mapping.items():
    template_index = v + mapping_offset
    templates_all_atom_positions[k] = all_atom_positions[template_index][0]
    templates_all_atom_masks[k] = all_atom_masks[template_index][0]
    output_templates_sequence[k] = template_sequence[v]

  # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
  if np.sum(templates_all_atom_masks) < 5:
    raise TemplateAtomMaskAllZerosError(
        'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' %
        (pdb_id, chain_id, min(mapping.values()) + mapping_offset,
         max(mapping.values()) + mapping_offset))

  output_templates_sequence = ''.join(output_templates_sequence)

  templates_aatype = residue_constants.sequence_to_onehot(
      output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID)

  return (
      {
          'template_all_atom_positions': np.array(templates_all_atom_positions),
          'template_all_atom_masks': np.array(templates_all_atom_masks),
          'template_sequence': output_templates_sequence.encode(),
          'template_aatype': np.array(templates_aatype),
          'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
      },
      warning)


def _build_query_to_hit_index_mapping(
    hit_query_sequence: str,
    hit_sequence: str,
    indices_hit: Sequence[int],
    indices_query: Sequence[int],
    original_query_sequence: str) -> Mapping[int, int]:
  """Gets mapping from indices in original query sequence to indices in the hit.

  hit_query_sequence and hit_sequence are two aligned sequences containing gap
  characters. hit_query_sequence contains only the part of the original query
  sequence that matched the hit. When interpreting the indices from the .hhr, we
  need to correct for this to recover a mapping from original query sequence to
  the hit sequence.

  Args:
    hit_query_sequence: The portion of the query sequence that is in the .hhr
      hit.
    hit_sequence: The portion of the hit sequence that is in the .hhr hit.
    indices_hit: The indices for each amino acid relative to the hit sequence.
    indices_query: The indices for each amino acid relative to the original
      query sequence.
    original_query_sequence: String describing the original query sequence.

  Returns:
    Dictionary with indices in the original query sequence as keys and indices
    in the hit sequence as values.
  """
  # If the hit is empty (no aligned residues), return an empty mapping.
  if not hit_query_sequence:
    return {}

  # Remove gaps and find the offset of hit.query relative to original query.
  hhsearch_query_sequence = hit_query_sequence.replace('-', '')
  hit_sequence = hit_sequence.replace('-', '')
  hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence)

  # Index of -1 used for gap characters. Subtract the min index ignoring gaps.
  min_idx = min(x for x in indices_hit if x > -1)
  fixed_indices_hit = [
      x - min_idx if x > -1 else -1 for x in indices_hit
  ]

  min_idx = min(x for x in indices_query if x > -1)
  fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]

  # Zip the corrected indices, ignore case where both seqs have gap characters.
  mapping = {}
  for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
    if q_t != -1 and q_i != -1:
      if (q_t >= len(hit_sequence) or
          q_i + hhsearch_query_offset >= len(original_query_sequence)):
        continue
      mapping[q_i + hhsearch_query_offset] = q_t

  return mapping


@dataclasses.dataclass(frozen=True)
class SingleHitResult:
  features: Optional[Mapping[str, Any]]
  error: Optional[str]
  warning: Optional[str]


@functools.lru_cache(16, typed=False)
def _read_file(path):
  with open(path, 'r') as f:
    file_data = f.read()
  return file_data


def _process_single_hit(
    query_sequence: str,
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, Optional[str]],
    kalign_binary_path: str,
    strict_error_check: bool = False) -> SingleHitResult:
  """Tries to extract template features from a single HHSearch hit."""
  # Fail hard if we can't get the PDB ID and chain name from the hit.
  hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

  # This hit has been removed (obsoleted) from PDB, skip it.
  if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None:
    return SingleHitResult(
        features=None, error=None, warning=f'Hit {hit_pdb_code} is obsolete.')

  if hit_pdb_code not in release_dates:
    if hit_pdb_code in obsolete_pdbs:
      hit_pdb_code = obsolete_pdbs[hit_pdb_code]

  # Pass hit_pdb_code since it might have changed due to the pdb being obsolete.
  try:
    _assess_hhsearch_hit(
        hit=hit,
        hit_pdb_code=hit_pdb_code,
        query_sequence=query_sequence,
        release_dates=release_dates,
        release_date_cutoff=max_template_date)
  except PrefilterError as e:
    msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
    logging.info(msg)
    if strict_error_check and isinstance(e, (DateError, DuplicateError)):
      # In strict mode we treat some prefilter cases as errors.
      return SingleHitResult(features=None, error=msg, warning=None)

    return SingleHitResult(features=None, error=None, warning=None)

  mapping = _build_query_to_hit_index_mapping(
      hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query,
      query_sequence)

  # The mapping is from the query to the actual hit sequence, so we need to
  # remove gaps (which have no confidence score anyway).
  template_sequence = hit.hit_sequence.replace('-', '')

  cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
  logging.debug('Reading PDB entry from %s. Query: %s, template: %s', cif_path,
                query_sequence, template_sequence)
  # Fail if we can't find the mmCIF file.
  cif_string = _read_file(cif_path)

  parsing_result = mmcif_parsing.parse(
      file_id=hit_pdb_code, mmcif_string=cif_string)

  if parsing_result.mmcif_object is not None:
    hit_release_date = datetime.datetime.strptime(
        parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d')
    if hit_release_date > max_template_date:
      error = ('Template %s date (%s) > max template date (%s).' %
               (hit_pdb_code, hit_release_date, max_template_date))
      if strict_error_check:
        return SingleHitResult(features=None, error=error, warning=None)
      else:
        logging.debug(error)
        return SingleHitResult(features=None, error=None, warning=None)

  try:
    features, realign_warning = _extract_template_features(
        mmcif_object=parsing_result.mmcif_object,
        pdb_id=hit_pdb_code,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=hit_chain_id,
        kalign_binary_path=kalign_binary_path)
    if hit.sum_probs is None:
      features['template_sum_probs'] = [0]
    else:
      features['template_sum_probs'] = [hit.sum_probs]

    # It is possible there were some errors when parsing the other chains in the
    # mmCIF file, but the template features for the chain we want were still
    # computed. In such a case the mmCIF parsing errors are not relevant.
    return SingleHitResult(
        features=features, error=None, warning=realign_warning)
  except (NoChainsError, NoAtomDataInTemplateError,
          TemplateAtomMaskAllZerosError) as e:
    # These 3 errors indicate missing mmCIF experimental data rather than a
    # problem with the template search, so turn them into warnings.
    warning = ('%s_%s (sum_probs: %s, rank: %s): feature extraction errors: '
               '%s, mmCIF parsing errors: %s'
               % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
                  str(e), parsing_result.errors))
    if strict_error_check:
      return SingleHitResult(features=None, error=warning, warning=None)
    else:
      return SingleHitResult(features=None, error=None, warning=warning)
  except Error as e:
    error = ('%s_%s (sum_probs: %s, rank: %s): feature extraction errors: '
             '%s, mmCIF parsing errors: %s'
             % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
                str(e), parsing_result.errors))
    return SingleHitResult(features=None, error=error, warning=None)


@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
  features: Mapping[str, Any]
  errors: Sequence[str]
  warnings: Sequence[str]


class TemplateHitFeaturizer(abc.ABC):
  """An abstract base class for turning template hits to template features."""

  def __init__(
      self,
      mmcif_dir: str,
      max_template_date: str,
      max_hits: int,
      kalign_binary_path: str,
      release_dates_path: Optional[str],
      obsolete_pdbs_path: Optional[str],
      strict_error_check: bool = False):
    """Initializes the Template Search.

    Args:
      mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
        is found by HHSearch, this directory is used to retrieve the template
        data.
      max_template_date: The maximum date permitted for template structures. No
        template with a release date later than this date will be returned. In
        ISO 8601 date format, YYYY-MM-DD.
      max_hits: The maximum number of templates that will be returned.
      kalign_binary_path: The path to a kalign executable used for template
        realignment.
      release_dates_path: An optional path to a file with a mapping from PDB IDs
        to their release dates. Thanks to this we don't have to redundantly
        parse mmCIF files to get that information.
      obsolete_pdbs_path: An optional path to a file containing a mapping from
        obsolete PDB IDs to the PDB IDs of their replacements.
      strict_error_check: If True, then the following will be treated as errors:
        * If any template date is after the max_template_date.
        * If any template has identical PDB ID to the query.
        * If any template is a duplicate of the query.
        * Any feature computation errors.
    """
    self._mmcif_dir = mmcif_dir
    if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')):
      logging.error('Could not find CIFs in %s', self._mmcif_dir)
      raise ValueError(f'Could not find CIFs in {self._mmcif_dir}')

    try:
      self._max_template_date = datetime.datetime.strptime(
          max_template_date, '%Y-%m-%d')
    except ValueError:
      raise ValueError(
          'max_template_date must be set and have format YYYY-MM-DD.')
    self._max_hits = max_hits
    self._kalign_binary_path = kalign_binary_path
    self._strict_error_check = strict_error_check

    if release_dates_path:
      logging.info('Using precomputed release dates %s.', release_dates_path)
      self._release_dates = _parse_release_dates(release_dates_path)
    else:
      self._release_dates = {}

    if obsolete_pdbs_path:
      logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path)
      self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
    else:
      self._obsolete_pdbs = {}

  @abc.abstractmethod
  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence."""


class HhsearchHitFeaturizer(TemplateHitFeaturizer):
  """A class for turning a3m hits from hhsearch to template features."""

  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence (more details above)."""
    logging.info('Searching for template for: %s', query_sequence)

    template_features = {}
    for template_feature_name in TEMPLATE_FEATURES:
      template_features[template_feature_name] = []

    num_hits = 0
    errors = []
    warnings = []

    for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True):
      # We got all the templates we wanted, stop processing hits.
      if num_hits >= self._max_hits:
        break

      result = _process_single_hit(
          query_sequence=query_sequence,
          hit=hit,
          mmcif_dir=self._mmcif_dir,
          max_template_date=self._max_template_date,
          release_dates=self._release_dates,
          obsolete_pdbs=self._obsolete_pdbs,
          strict_error_check=self._strict_error_check,
          kalign_binary_path=self._kalign_binary_path)

      if result.error:
        errors.append(result.error)

      # There could be an error even if there are some results, e.g. thrown by
      # other unparsable chains in the same mmCIF file.
      if result.warning:
        warnings.append(result.warning)

      if result.features is None:
        logging.info('Skipped invalid hit %s, error: %s, warning: %s',
                     hit.name, result.error, result.warning)
      else:
        # Increment the hit counter, since we got features out of this hit.
        num_hits += 1
        for k in template_features:
          template_features[k].append(result.features[k])

    for name in template_features:
      if num_hits > 0:
        template_features[name] = np.stack(
            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
      else:
        # Make sure the feature has correct dtype even if empty.
        template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name])

    return TemplateSearchResult(
        features=template_features, errors=errors, warnings=warnings)


class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
  """A class for turning a3m hits from hmmsearch to template features."""

  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence (more details above)."""
    logging.info('Searching for template for: %s', query_sequence)

    template_features = {}
    for template_feature_name in TEMPLATE_FEATURES:
      template_features[template_feature_name] = []

    already_seen = set()
    errors = []
    warnings = []

    if not hits or hits[0].sum_probs is None:
      sorted_hits = hits
    else:
      sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)

    for hit in sorted_hits:
      # We got all the templates we wanted, stop processing hits.
      if len(already_seen) >= self._max_hits:
        break

      result = _process_single_hit(
          query_sequence=query_sequence,
          hit=hit,
          mmcif_dir=self._mmcif_dir,
          max_template_date=self._max_template_date,
          release_dates=self._release_dates,
          obsolete_pdbs=self._obsolete_pdbs,
          strict_error_check=self._strict_error_check,
          kalign_binary_path=self._kalign_binary_path)

      if result.error:
        errors.append(result.error)

      # There could be an error even if there are some results, e.g. thrown by
      # other unparsable chains in the same mmCIF file.
      if result.warning:
        warnings.append(result.warning)

      if result.features is None:
        logging.debug('Skipped invalid hit %s, error: %s, warning: %s',
                      hit.name, result.error, result.warning)
      else:
        already_seen_key = result.features['template_sequence']
        if already_seen_key in already_seen:
          continue
        # Increment the hit counter, since we got features out of this hit.
        already_seen.add(already_seen_key)
        for k in template_features:
          template_features[k].append(result.features[k])

    if already_seen:
      for name in template_features:
        template_features[name] = np.stack(
            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
    else:
      num_res = len(query_sequence)
      # Construct a default template with all zeros.
      template_features = {
          'template_aatype': np.zeros(
              (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
              np.float32),
          'template_all_atom_masks': np.zeros(
              (1, num_res, residue_constants.atom_type_num), np.float32),
          'template_all_atom_positions': np.zeros(
              (1, num_res, residue_constants.atom_type_num, 3), np.float32),
          'template_domain_names': np.array([''.encode()], dtype=object),
          'template_sequence': np.array([''.encode()], dtype=object),
          'template_sum_probs': np.array([0], dtype=np.float32)
      }
    return TemplateSearchResult(
        features=template_features, errors=errors, warnings=warnings)
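

# Example usage sketch (paths and inputs below are hypothetical placeholders):
#
#   featurizer = HhsearchHitFeaturizer(
#       mmcif_dir='/path/to/pdb_mmcif/mmcif_files',
#       max_template_date='2021-11-01',
#       max_hits=20,
#       kalign_binary_path='/usr/bin/kalign',
#       release_dates_path=None,
#       obsolete_pdbs_path='/path/to/pdb_mmcif/obsolete.dat')
#   hits = parsers.parse_hhr(hhr_string)  # Hits parsed from HHSearch output.
#   result = featurizer.get_templates(query_sequence=query_sequence, hits=hits)
#
# result.features then holds the arrays listed in TEMPLATE_FEATURES, while any
# errors and warnings collected during featurization are in result.errors and
# result.warnings.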