protein.py 9.73 KB
Newer Older
Augustin-Zidek's avatar
Augustin-Zidek committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Protein data type."""
Tom Ward's avatar
Tom Ward committed
16
import dataclasses
Augustin-Zidek's avatar
Augustin-Zidek committed
17
18
import io
from typing import Any, Mapping, Optional
Tom Ward's avatar
Tom Ward committed
19
from alphafold.common import residue_constants
Augustin-Zidek's avatar
Augustin-Zidek committed
20
21
22
23
24
25
from Bio.PDB import PDBParser
import numpy as np

# Type alias for a mapping from feature name to a numpy array.
FeatureDict = Mapping[str, np.ndarray]
# Type alias for a mapping from output name to an arbitrary value.
ModelOutput = Mapping[str, Any]  # Is a nested dict.

26
27
28
29
# Complete sequence of chain IDs supported by the PDB format. The position of
# a character in this string is the integer chain index that it is used for
# when a protein is written out.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
# Largest number of distinct chains a single protein may contain.
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.

Augustin-Zidek's avatar
Augustin-Zidek committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

@dataclasses.dataclass(frozen=True)
class Protein:
  """Immutable bundle of per-residue arrays describing a protein structure.

  Construction raises a `ValueError` when `chain_index` holds more distinct
  values than `PDB_MAX_CHAINS`.
  """

  # Atom coordinates (Cartesian, in angstroms). Atom types are ordered as in
  # residue_constants.atom_types, so the leading three are N, CA, CB.
  atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

  # Integer amino-acid type of every residue, from 0 to 20 inclusive; the
  # value 20 stands for 'X'.
  aatype: np.ndarray  # [num_res]

  # Float mask flagging which atoms exist: 1.0 for a present atom, 0.0 for an
  # absent one. Intended for masking losses.
  atom_mask: np.ndarray  # [num_res, num_atom_type]

  # Residue numbering as found in PDB. May contain gaps and need not start at
  # zero.
  residue_index: np.ndarray  # [num_res]

  # Zero-based number of the chain that each residue is part of.
  chain_index: np.ndarray  # [num_res]

  # B-factors (temperature factors) in square angstroms, describing the
  # displacement from the ground truth mean value.
  b_factors: np.ndarray  # [num_res, num_atom_type]

  def __post_init__(self):
    """Rejects proteins with more chains than the PDB format can label."""
    num_chains = np.unique(self.chain_index).size
    if num_chains <= PDB_MAX_CHAINS:
      return
    raise ValueError(
        f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
        'because these cannot be written to PDB format.')

Augustin-Zidek's avatar
Augustin-Zidek committed
65
66
67
68
69
70
71
72
73

def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
  """Parses PDB-formatted text into a `Protein`.

  WARNING: Residue types that are not standard are mapped to UNK, and atoms
    that are not standard are dropped.

  Args:
    pdb_str: Text content of a PDB file.
    chain_id: When given (e.g. A), only the chain with this ID is read. When
      None, every chain is read.

  Returns:
    A new `Protein` built from the parsed contents.

  Raises:
    ValueError: If the PDB does not hold exactly one model, or if any parsed
      residue carries an insertion code.
  """
  parser = PDBParser(QUIET=True)
  structure = parser.get_structure('none', io.StringIO(pdb_str))
  models = list(structure.get_models())
  if len(models) != 1:
    raise ValueError(
        f'Only single model PDBs are supported. Found {len(models)} models.')
  model = models[0]

  num_atom_types = residue_constants.atom_type_num

  # One entry per kept residue in each of these lists.
  restype_per_res = []
  positions_per_res = []
  mask_per_res = []
  number_per_res = []
  chain_id_per_res = []
  b_factors_per_res = []

  for chain in model:
    chain_selected = chain_id is None or chain.id == chain_id
    if not chain_selected:
      continue

    for res in chain:
      if res.id[2] != ' ':
        raise ValueError(
            f'PDB contains an insertion code at chain {chain.id} and residue '
            f'index {res.id[1]}. These are not supported.')

      # Unknown three-letter names fall back to 'X', which in turn falls back
      # to the index residue_constants.restype_num.
      one_letter = residue_constants.restype_3to1.get(res.resname, 'X')
      restype_idx = residue_constants.restype_order.get(
          one_letter, residue_constants.restype_num)

      res_positions = np.zeros((num_atom_types, 3))
      res_mask = np.zeros((num_atom_types,))
      res_b_factors = np.zeros((num_atom_types,))
      for atom in res:
        if atom.name not in residue_constants.atom_types:
          continue
        slot = residue_constants.atom_order[atom.name]
        res_positions[slot] = atom.coord
        res_mask[slot] = 1.
        res_b_factors[slot] = atom.bfactor

      if res_mask.sum() < 0.5:
        # None of the residue's atoms is a known atom type, so drop it.
        continue

      restype_per_res.append(restype_idx)
      positions_per_res.append(res_positions)
      mask_per_res.append(res_mask)
      number_per_res.append(res.id[1])
      chain_id_per_res.append(chain.id)
      b_factors_per_res.append(res_b_factors)

  # Chain IDs are usually characters; number them by their position in the
  # sorted set of IDs that np.unique returns.
  sorted_chain_ids = np.unique(chain_id_per_res)
  chain_number = dict(zip(sorted_chain_ids, range(len(sorted_chain_ids))))
  chain_index = np.array([chain_number[cid] for cid in chain_id_per_res])

  return Protein(
      atom_positions=np.array(positions_per_res),
      atom_mask=np.array(mask_per_res),
      aatype=np.array(restype_per_res),
      residue_index=np.array(number_per_res),
      chain_index=chain_index,
      b_factors=np.array(b_factors_per_res))


140
141
142
143
144
145
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
  chain_end = 'TER'
  return (f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
          f'{chain_name:>1}{residue_index:>4}')


Augustin-Zidek's avatar
Augustin-Zidek committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def to_pdb(prot: Protein) -> str:
  """Serializes a `Protein` into PDB-formatted text.

  The output holds a single MODEL containing one ATOM record for every atom
  whose mask value is at least 0.5, a TER record after each chain, and closing
  ENDMDL and END records. Every line is padded to 80 characters.

  Args:
    prot: The protein to convert to PDB.

  Returns:
    PDB string, ending in a newline.

  Raises:
    ValueError: If `prot.aatype` holds a value above
      `residue_constants.restype_num`, or if a chain index is not below
      `PDB_MAX_CHAINS`.
  """
  restype_letters = residue_constants.restypes + ['X']

  def three_letter_code(restype_idx):
    # Residue types without a three-letter name are written as UNK.
    return residue_constants.restype_1to3.get(
        restype_letters[restype_idx], 'UNK')

  atom_names = residue_constants.atom_types

  lines = []

  mask_per_res = prot.atom_mask
  aatype = prot.aatype
  positions_per_res = prot.atom_positions
  residue_index = prot.residue_index.astype(np.int32)
  chain_index = prot.chain_index.astype(np.int32)
  b_factors_per_res = prot.b_factors

  if np.any(aatype > residue_constants.restype_num):
    raise ValueError('Invalid aatypes.')

  # Look-up table from integer chain index to the one-character chain ID.
  chain_letters = {}
  for chain_number in np.unique(chain_index):  # Sorted ascending.
    if chain_number >= PDB_MAX_CHAINS:
      raise ValueError(
          f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
    chain_letters[chain_number] = PDB_CHAIN_IDS[chain_number]

  lines.append('MODEL     1')
  serial = 1
  open_chain = chain_index[0]

  # Emit the atom records residue by residue.
  for res_pos in range(aatype.shape[0]):
    current_chain = chain_index[res_pos]

    # A change of chain index means the previous residue ended a chain.
    if current_chain != open_chain:
      lines.append(_chain_end(
          serial, three_letter_code(aatype[res_pos - 1]),
          chain_letters[open_chain], residue_index[res_pos - 1]))
      open_chain = current_chain
      serial += 1  # The TER record consumes a serial number as well.

    resname = three_letter_code(aatype[res_pos])
    chain_letter = chain_letters[current_chain]
    residue_number = residue_index[res_pos]

    atoms = zip(atom_names, positions_per_res[res_pos],
                mask_per_res[res_pos], b_factors_per_res[res_pos])
    for atom_name, xyz, present, b_factor in atoms:
      if present < 0.5:
        continue

      record_type = 'ATOM'
      if len(atom_name) == 4:
        padded_name = atom_name
      else:
        padded_name = f' {atom_name}'
      alt_loc = ''
      insertion_code = ''
      occupancy = 1.00
      # Only C, N, O and S occur in proteins, so the first letter suffices.
      element = atom_name[0]
      charge = ''

      # PDB is a fixed-column format: the widths and literal blanks below
      # must not change.
      atom_line = ''.join((
          f'{record_type:<6}{serial:>5} ',
          f'{padded_name:<4}{alt_loc:>1}',
          f'{resname:>3} {chain_letter:>1}',
          f'{residue_number:>4}{insertion_code:>1}   ',
          f'{xyz[0]:>8.3f}{xyz[1]:>8.3f}{xyz[2]:>8.3f}',
          f'{occupancy:>6.2f}{b_factor:>6.2f}',
          ' ' * 10,
          f'{element:>2}{charge:>2}',
      ))
      lines.append(atom_line)
      serial += 1

  # The last chain still needs its TER record.
  lines.append(_chain_end(
      serial, three_letter_code(aatype[-1]),
      chain_letters[chain_index[-1]], residue_index[-1]))
  lines.extend(('ENDMDL', 'END'))

  # Left-justify every line to 80 characters and terminate each with a
  # newline, which also yields the trailing newline.
  return ''.join(f'{line:<80}\n' for line in lines)
Augustin-Zidek's avatar
Augustin-Zidek committed
224
225
226
227
228
229
230


def ideal_atom_mask(prot: Protein) -> np.ndarray:
  """Looks up the atom mask implied by the amino-acid sequence alone.

  `Protein.atom_mask` typically reflects which atoms are reported in the PDB.
  The mask returned here instead marks the heavy atoms that should be present
  for the given sequence of amino acids.

  Args:
    prot: `Protein` whose fields are `numpy.ndarray` objects.

  Returns:
    An ideal atom mask, obtained by indexing
    `residue_constants.STANDARD_ATOM_MASK` with `prot.aatype`.
  """
  standard_mask_table = residue_constants.STANDARD_ATOM_MASK
  return standard_mask_table[prot.aatype]


242
243
244
245
246
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True) -> Protein:
  """Builds a `Protein` from model inputs and the matching model outputs.

  Args:
    features: Dictionary holding model inputs.
    result: Dictionary holding model outputs.
    b_factors: (Optional) B-factors to use for the protein. When None, zeros
      shaped like the predicted atom mask are used.
    remove_leading_feature_dimension: Whether to strip the leading dimension
      from the `features` values that are read.

  Returns:
    A protein instance.
  """
  structure_out = result['structure_module']

  def _feature(key: str) -> np.ndarray:
    # Reads one input feature, dropping its leading dimension if requested.
    value = features[key]
    if remove_leading_feature_dimension:
      return value[0]
    return value

  # Without chain information every residue is assigned to chain 0.
  chain_index = (
      _feature('asym_id') if 'asym_id' in features
      else np.zeros_like(_feature('aatype')))

  b_factors = (
      np.zeros_like(structure_out['final_atom_mask']) if b_factors is None
      else b_factors)

  return Protein(
      aatype=_feature('aatype'),
      atom_positions=structure_out['final_atom_positions'],
      atom_mask=structure_out['final_atom_mask'],
      # NOTE(review): presumably shifts a 0-based feature index to 1-based
      # residue numbering -- confirm against the feature pipeline.
      residue_index=_feature('residue_index') + 1,
      chain_index=chain_index,
      b_factors=b_factors)