data_pipeline.py 14.7 KB
Newer Older
1
2
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
3
#
4
5
6
7
8
9
10
11
12
13
14
15
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
import os
import datetime
18
from typing import Mapping, Optional, Sequence, Any
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
20
21

import numpy as np

22
23
24
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date
25
from openfold.np import residue_constants, protein
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
26
27
28
29


FeatureDict = Mapping[str, np.ndarray]

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
30

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31
def make_sequence_features(
    sequence: str, description: str, num_res: int
) -> FeatureDict:
    """Build the per-sequence feature dict.

    Args:
        sequence: One-letter amino-acid sequence. Unknown letters are mapped
            to ``X`` by ``residue_constants.sequence_to_onehot``.
        description: Free-text name stored (UTF-8 encoded) under
            ``domain_name``.
        num_res: Number of residues; sizes every per-residue array.

    Returns:
        Dict with ``aatype`` (one-hot), ``between_segment_residues`` (zeros),
        ``domain_name``, ``residue_index`` (0..num_res-1), ``seq_length``
        (``num_res`` repeated per residue) and ``sequence``.
    """
    onehot_aatype = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True,
    )
    return {
        "aatype": onehot_aatype,
        "between_segment_residues": np.zeros((num_res,), dtype=np.int32),
        "domain_name": np.array(
            [description.encode("utf-8")], dtype=np.object_
        ),
        "residue_index": np.arange(num_res, dtype=np.int32),
        "seq_length": np.full((num_res,), num_res, dtype=np.int32),
        "sequence": np.array([sequence.encode("utf-8")], dtype=np.object_),
    }


def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    """Build the feature dict for one chain of a parsed mmCIF object.

    Args:
        mmcif_object: Parsed mmCIF file.
        chain_id: Key into ``mmcif_object.chain_to_seqres`` selecting the
            chain to featurize.

    Returns:
        The output of ``make_sequence_features`` for the chain's SEQRES
        sequence (described as ``"<file_id>_<chain_id>"``), extended with
        ``all_atom_positions`` / ``all_atom_mask`` from
        ``mmcif_parsing.get_atom_coords``, the header ``resolution`` and
        ``release_date``, and ``is_distillation`` set to 0.
    """
    seqres = mmcif_object.chain_to_seqres[chain_id]
    description = "_".join([mmcif_object.file_id, chain_id])

    mmcif_feats = dict(
        make_sequence_features(
            sequence=seqres,
            description=description,
            num_res=len(seqres),
        )
    )

    positions, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    mmcif_feats["all_atom_positions"] = positions
    mmcif_feats["all_atom_mask"] = mask

    header = mmcif_object.header
    mmcif_feats["resolution"] = np.array(
        [header["resolution"]], dtype=np.float32
    )
    mmcif_feats["release_date"] = np.array(
        [header["release_date"].encode("utf-8")], dtype=np.object_
    )

    # mmCIF-derived examples carry is_distillation = 0.
    mmcif_feats["is_distillation"] = np.array(0.0, dtype=np.float32)

    return mmcif_feats


89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def make_pdb_features(
        protein_object: protein.Protein,
        description: str,
        confidence_threshold: float = 0.5,
) -> FeatureDict:
    """Build the feature dict for a structure parsed from a PDB file.

    Every atom of a residue is masked out of ``all_atom_mask`` unless at
    least one of that residue's atoms has a B-factor strictly greater than
    ``confidence_threshold``. The example is flagged ``is_distillation = 1``.

    Args:
        protein_object: Parsed structure. ``aatype`` holds integer indices
            into ``residue_constants.restypes_with_x``.
        description: Name stored under ``domain_name``.
        confidence_threshold: Per-atom B-factor cutoff.
            NOTE(review): if the B-factor column holds pLDDT on a 0-100
            scale, this default keeps essentially every residue — confirm
            the intended scale.

    Returns:
        Sequence features plus ``all_atom_positions``, ``all_atom_mask`` and
        ``is_distillation``. ``protein_object`` is not modified.
    """
    aatype = protein_object.aatype
    # make_sequence_features expects a one-letter string, not integer codes.
    sequence = "".join(residue_constants.restypes_with_x[aa] for aa in aatype)

    pdb_feats = dict(
        make_sequence_features(
            sequence=sequence,
            description=description,
            num_res=len(aatype),
        )
    )

    # A residue is kept if any of its atoms clears the confidence cutoff.
    high_confidence = np.any(
        protein_object.b_factors > confidence_threshold, axis=-1
    )

    # Work on a copy so the caller's Protein.atom_mask is left untouched.
    all_atom_mask = np.array(protein_object.atom_mask, copy=True)
    all_atom_mask[~high_confidence] = 0

    pdb_feats["all_atom_positions"] = protein_object.atom_positions
    pdb_feats["all_atom_mask"] = all_atom_mask

    pdb_feats["is_distillation"] = np.array(1.).astype(np.float32)

    return pdb_feats


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
121
def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
    """Merge several MSAs into a single deduplicated MSA feature dict.

    Sequences are visited MSA by MSA, in order; only the first occurrence of
    each distinct sequence string is kept, along with its row from the
    matching deletion matrix.

    Args:
        msas: One or more MSAs, each a non-empty sequence of aligned strings.
            ``msas[0][0]`` determines ``num_res``.
        deletion_matrices: Parallel to ``msas``; ``deletion_matrices[i][j]``
            is the deletion row for ``msas[i][j]``.

    Returns:
        Dict with ``deletion_matrix_int``, ``msa`` (HHblits residue IDs) and
        ``num_alignments`` (the kept row count repeated ``num_res`` times).

    Raises:
        ValueError: If ``msas`` is empty or any individual MSA is empty.
    """
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    encoded_rows = []
    kept_deletion_rows = []
    already_seen = set()
    for msa_idx, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f"MSA {msa_idx} must contain at least one sequence."
            )
        for row_idx, aligned_seq in enumerate(msa):
            if aligned_seq not in already_seen:
                already_seen.add(aligned_seq)
                encoded_rows.append(
                    [residue_constants.HHBLITS_AA_TO_ID[res] for res in aligned_seq]
                )
                kept_deletion_rows.append(deletion_matrices[msa_idx][row_idx])

    num_res = len(msas[0][0])
    return {
        "deletion_matrix_int": np.array(kept_deletion_rows, dtype=np.int32),
        "msa": np.array(encoded_rows, dtype=np.int32),
        "num_alignments": np.full(
            (num_res,), len(encoded_rows), dtype=np.int32
        ),
    }


class AlignmentRunner:
    """Runs the alignment tools for a query sequence and saves the results.

    UniRef90 and MGnify are searched with Jackhmmer and PDB70 with HHsearch.
    The BFD search depends on ``use_small_bfd``:
      * If it is set, Jackhmmer is run against the small BFD.
      * Otherwise, HHblits is run against BFD + Uniclust30.
    """

    def __init__(
        self,
        jackhmmer_binary_path: str,
        hhblits_binary_path: str,
        hhsearch_binary_path: str,
        uniref90_database_path: str,
        mgnify_database_path: str,
        bfd_database_path: Optional[str],
        uniclust30_database_path: Optional[str],
        small_bfd_database_path: Optional[str],
        pdb70_database_path: str,
        use_small_bfd: bool,
        no_cpus: int,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
    ):
        """
        Args:
            jackhmmer_binary_path: Jackhmmer executable, used for UniRef90,
                MGnify and (if enabled) small-BFD searches.
            hhblits_binary_path: HHblits executable, used for BFD +
                Uniclust30 when ``use_small_bfd`` is False.
            hhsearch_binary_path: HHsearch executable for PDB70.
            uniref90_database_path: UniRef90 database.
            mgnify_database_path: MGnify database.
            bfd_database_path: BFD database; only used when
                ``use_small_bfd`` is False.
            uniclust30_database_path: Uniclust30 database; only used when
                ``use_small_bfd`` is False.
            small_bfd_database_path: Small BFD database; only used when
                ``use_small_bfd`` is True.
            pdb70_database_path: PDB70 database.
            use_small_bfd: Selects which BFD search is configured.
            no_cpus: CPU count handed to the Jackhmmer and HHblits runners.
            uniref_max_hits: Max sequences kept when converting the UniRef90
                Stockholm output to A3M.
            mgnify_max_hits: Same, for MGnify.
        """
        self._use_small_bfd = use_small_bfd

        def make_jackhmmer(database_path):
            # Every Jackhmmer search shares the same binary and CPU count.
            return jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=database_path,
                n_cpu=no_cpus,
            )

        self.jackhmmer_uniref90_runner = make_jackhmmer(uniref90_database_path)

        if not use_small_bfd:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path],
                n_cpu=no_cpus,
            )
        else:
            self.jackhmmer_small_bfd_runner = make_jackhmmer(
                small_bfd_database_path
            )

        self.jackhmmer_mgnify_runner = make_jackhmmer(mgnify_database_path)

        self.hhsearch_pdb70_runner = hhsearch.HHSearch(
            binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
        )
        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits

    def run(
        self,
        fasta_path: str,
        output_dir: str,
    ):
        """Runs every configured search on ``fasta_path``.

        Writes ``uniref90_hits.a3m``, ``mgnify_hits.a3m`` and
        ``pdb70_hits.hhr`` into ``output_dir``. It also writes one BFD file:
        ``small_bfd_hits.sto`` when the small BFD is used, otherwise
        ``bfd_uniclust_hits.a3m``.
        """
        def save(file_name, contents):
            with open(os.path.join(output_dir, file_name), "w") as out_file:
                out_file.write(contents)

        uniref90_a3m = parsers.convert_stockholm_to_a3m(
            self.jackhmmer_uniref90_runner.query(fasta_path)[0]["sto"],
            max_sequences=self.uniref_max_hits,
        )
        save("uniref90_hits.a3m", uniref90_a3m)

        mgnify_a3m = parsers.convert_stockholm_to_a3m(
            self.jackhmmer_mgnify_runner.query(fasta_path)[0]["sto"],
            max_sequences=self.mgnify_max_hits,
        )
        save("mgnify_hits.a3m", mgnify_a3m)

        # The template search is seeded with the UniRef90 alignment (A3M).
        save("pdb70_hits.hhr", self.hhsearch_pdb70_runner.query(uniref90_a3m))

        if self._use_small_bfd:
            small_bfd_result = self.jackhmmer_small_bfd_runner.query(
                fasta_path
            )[0]
            save("small_bfd_hits.sto", small_bfd_result["sto"])
        else:
            bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(
                fasta_path
            )
            save("bfd_uniclust_hits.a3m", bfd_uniclust_result["a3m"])
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
254
255
256
257


class DataPipeline:
    """Assembles input features from a query and precomputed alignments."""

    def __init__(
        self,
        template_featurizer: templates.TemplateHitFeaturizer,
        use_small_bfd: bool,
    ):
        """
        Args:
            template_featurizer: Turns HHsearch hits into template features.
            use_small_bfd: Selects which BFD output file is read from the
                alignment directory. It should match the setting of the
                ``AlignmentRunner`` that produced that directory.
        """
        self.template_featurizer = template_featurizer
        self.use_small_bfd = use_small_bfd

    def _parse_alignment_output(
        self,
        alignment_dir: str,
    ) -> Mapping[str, Any]:
        """Parses the alignment files stored in ``alignment_dir``.

        The file names are the ones written by ``AlignmentRunner.run``.

        Returns:
            Dict with ``*_msa`` / ``*_deletion_matrix`` entries for UniRef90,
            MGnify and BFD, plus the parsed PDB70 ``hhsearch_hits``.
        """
        uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
        with open(uniref90_out_path, "r") as f:
            uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())

        mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
        with open(mgnify_out_path, "r") as f:
            mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())

        pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
        with open(pdb70_out_path, "r") as f:
            hhsearch_hits = parsers.parse_hhr(f.read())

        if self.use_small_bfd:
            bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
            with open(bfd_out_path, "r") as f:
                # The third value returned by the Stockholm parser is unused.
                bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
                    f.read()
                )
        else:
            bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
            with open(bfd_out_path, "r") as f:
                bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())

        return {
            "uniref90_msa": uniref90_msa,
            "uniref90_deletion_matrix": uniref90_deletion_matrix,
            "mgnify_msa": mgnify_msa,
            "mgnify_deletion_matrix": mgnify_deletion_matrix,
            "hhsearch_hits": hhsearch_hits,
            "bfd_msa": bfd_msa,
            "bfd_deletion_matrix": bfd_deletion_matrix,
        }

    def _process_msa_feats(
        self,
        alignments: Mapping[str, Any],
    ) -> FeatureDict:
        """Builds MSA features from parsed alignments.

        The MSAs are merged in UniRef90, BFD, MGnify order. That order matters
        because ``make_msa_features`` keeps only the first occurrence of each
        sequence.
        """
        return make_msa_features(
            msas=(
                alignments["uniref90_msa"],
                alignments["bfd_msa"],
                alignments["mgnify_msa"],
            ),
            deletion_matrices=(
                alignments["uniref90_deletion_matrix"],
                alignments["bfd_deletion_matrix"],
                alignments["mgnify_deletion_matrix"],
            ),
        )

    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
        """Assembles features for the single sequence in a FASTA file.

        Raises:
            ValueError: If the file does not contain exactly one sequence.
        """
        with open(fasta_path) as f:
            fasta_str = f.read()

        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            # Covers both the empty and the multi-sequence case.
            raise ValueError(
                f"Expected exactly one input sequence in {fasta_path}, "
                f"found {len(input_seqs)}."
            )
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        alignments = self._parse_alignment_output(alignment_dir)

        # A bare FASTA query has no PDB code or release date to pass on.
        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=None,
            hits=alignments["hhsearch_hits"],
        )

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        msa_features = self._process_msa_feats(alignments)

        return {
            **sequence_features,
            **msa_features,
            **templates_result.features,
        }

    def process_mmcif(
        self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
    ) -> FeatureDict:
        """Assembles features for a specific chain in an mmCIF object.

        If ``chain_id`` is None, the first chain yielded by
        ``mmcif.structure.get_chains()`` is used. Any other chains are
        ignored.

        Raises:
            ValueError: If ``chain_id`` is None and the object has no chains.
        """
        if chain_id is None:
            chains = mmcif.structure.get_chains()
            chain = next(chains, None)
            if chain is None:
                raise ValueError("No chains in mmCIF file")
            chain_id = chain.id

        mmcif_feats = make_mmcif_features(mmcif, chain_id)

        alignments = self._parse_alignment_output(alignment_dir)

        input_sequence = mmcif.chain_to_seqres[chain_id]
        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=to_date(mmcif.header["release_date"]),
            hits=alignments["hhsearch_hits"],
        )

        msa_features = self._process_msa_feats(alignments)

        return {**mmcif_feats, **templates_result.features, **msa_features}

    def process_pdb(
        self,
        pdb_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
        """Assembles features for the protein in a PDB file.

        The file's base name without its extension is used as the feature
        ``description``. No release date is passed to the template
        featurizer, as in ``process_fasta``.
        """
        with open(pdb_path, "r") as f:
            pdb_str = f.read()

        protein_object = protein.from_pdb_string(pdb_str)

        # Template search needs the one-letter sequence; aatype holds
        # integer indices into residue_constants.restypes_with_x.
        input_sequence = "".join(
            residue_constants.restypes_with_x[aa]
            for aa in protein_object.aatype
        )
        description = os.path.splitext(os.path.basename(pdb_path))[0]

        pdb_feats = make_pdb_features(protein_object, description)

        alignments = self._parse_alignment_output(alignment_dir)

        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=None,
            hits=alignments["hhsearch_hits"],
        )

        msa_features = self._process_msa_feats(alignments)

        return {**pdb_feats, **templates_result.features, **msa_features}