# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
import os
import datetime
18
from typing import Mapping, Optional, Sequence, Any
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
20
21

import numpy as np

22
23
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.tools import jackhmmer, hhblits, hhsearch
24
from openfold.data.tools.utils import to_date 
25
from openfold.np import residue_constants, protein
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
26
27
28
29


FeatureDict = Mapping[str, np.ndarray]

30
31
32
33
34
35
36
37
38
def empty_template_feats(n_res) -> FeatureDict:
    """Return template features describing zero templates.

    Every array has a leading dimension of 0, so the result is a
    placeholder for a query of ``n_res`` residues that has no templates.

    Args:
        n_res: Number of residues in the query sequence.

    Returns:
        A dict with the keys ``template_aatype``,
        ``template_all_atom_positions``, ``template_sum_probs`` and
        ``template_all_atom_mask``.
    """
    aatype = np.zeros((0, n_res), dtype=np.int64)
    atom_positions = np.zeros((0, n_res, 37, 3), dtype=np.float32)
    sum_probs = np.zeros((0, 1), dtype=np.float32)
    atom_mask = np.zeros((0, n_res, 37), dtype=np.float32)

    return {
        "template_aatype": aatype,
        "template_all_atom_positions": atom_positions,
        "template_sum_probs": sum_probs,
        "template_all_atom_mask": atom_mask,
    }

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
39

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
40
def make_sequence_features(
    sequence: str, description: str, num_res: int
) -> FeatureDict:
    """Build the per-sequence features for a single chain.

    Args:
        sequence: Amino-acid sequence as a string of residue letters.
        description: Free-text name of the sequence; stored UTF-8 encoded
            under ``domain_name``.
        num_res: Number of residues; sets the length of the per-residue
            arrays.

    Returns:
        A dict with the keys ``aatype``, ``between_segment_residues``,
        ``domain_name``, ``residue_index``, ``seq_length`` and ``sequence``.
    """
    return {
        "aatype": residue_constants.sequence_to_onehot(
            sequence=sequence,
            mapping=residue_constants.restype_order_with_x,
            map_unknown_to_x=True,
        ),
        "between_segment_residues": np.zeros((num_res,), dtype=np.int32),
        "domain_name": np.array(
            [description.encode("utf-8")], dtype=np.object_
        ),
        "residue_index": np.arange(num_res, dtype=np.int32),
        # The sequence length is repeated once per residue.
        "seq_length": np.full((num_res,), num_res, dtype=np.int32),
        "sequence": np.array([sequence.encode("utf-8")], dtype=np.object_),
    }


def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    """Build the feature dict for one chain of a parsed mmCIF object.

    Args:
        mmcif_object: Parsed mmCIF structure.
        chain_id: ID of the chain to featurize; used to look up the chain's
            sequence in ``mmcif_object.chain_to_seqres``.

    Returns:
        The sequence features of the chain plus ``all_atom_positions``,
        ``all_atom_mask``, ``resolution``, ``release_date`` and
        ``is_distillation`` (always 0 here).
    """
    seqres = mmcif_object.chain_to_seqres[chain_id]

    # The description is "<file id>_<chain id>".
    feats = dict(
        make_sequence_features(
            sequence=seqres,
            description="_".join((mmcif_object.file_id, chain_id)),
            num_res=len(seqres),
        )
    )

    atom_positions, atom_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    feats["all_atom_positions"] = atom_positions
    feats["all_atom_mask"] = atom_mask

    resolution = mmcif_object.header["resolution"]
    feats["resolution"] = np.array([resolution], dtype=np.float32)

    release_date = mmcif_object.header["release_date"]
    feats["release_date"] = np.array(
        [release_date.encode("utf-8")], dtype=np.object_
    )

    feats["is_distillation"] = np.array(0., dtype=np.float32)

    return feats


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def make_pdb_features(
        protein_object: protein.Protein,
        description: str,
        confidence_threshold: float = 0.5,
) -> FeatureDict:
    """Build the feature dict for a protein loaded from a PDB file.

    A residue is kept in the atom mask only if at least one of its atoms has
    a B-factor strictly greater than ``confidence_threshold``; the mask rows
    of all other residues are zeroed.

    Args:
        protein_object: Parsed protein; its ``aatype``, ``atom_positions``,
            ``atom_mask`` and ``b_factors`` attributes are read.
        description: Name stored in the ``domain_name`` feature.
        confidence_threshold: B-factor cut-off used to mask residues.

    Returns:
        The sequence features plus ``all_atom_positions``, ``all_atom_mask``,
        ``resolution`` (fixed at 0) and ``is_distillation`` (fixed at 1).
    """
    aatype = protein_object.aatype
    # make_sequence_features needs a residue-letter string (it calls
    # ``.encode`` on it), so translate the residue indices back to letters.
    # NOTE(review): assumes aatype indexes residue_constants.restypes_with_x,
    # the inverse of the restype_order_with_x mapping - confirm.
    sequence = "".join(
        residue_constants.restypes_with_x[res_index] for res_index in aatype
    )

    pdb_feats = dict(
        make_sequence_features(
            sequence=sequence,
            description=description,
            num_res=len(aatype),
        )
    )

    # Read the B-factors from the protein instance, not the `protein` module.
    high_confidence = np.any(
        protein_object.b_factors > confidence_threshold, axis=-1
    )

    # Mask a copy so the caller's protein object is left unmodified.
    all_atom_mask = np.array(protein_object.atom_mask, copy=True)
    all_atom_mask[~high_confidence] = 0

    pdb_feats["all_atom_positions"] = protein_object.atom_positions
    pdb_feats["all_atom_mask"] = all_atom_mask

    pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
    pdb_feats["is_distillation"] = np.array(1.).astype(np.float32)

    return pdb_feats


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
131
def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
    """Merge several MSAs into a single set of MSA features.

    Sequences are de-duplicated across all MSAs; only the first occurrence
    of each sequence (and its deletion-matrix row) is kept.

    Args:
        msas: The alignments, each a sequence of aligned sequence strings.
        deletion_matrices: One deletion matrix per MSA, indexed as
            ``deletion_matrices[msa_index][sequence_index]``.

    Returns:
        A dict with the keys ``deletion_matrix_int``, ``msa`` and
        ``num_alignments``.

    Raises:
        ValueError: If ``msas`` is empty or any MSA has no sequences.
    """
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    encoded_rows = []
    deletion_rows = []
    already_seen = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f"MSA {msa_index} must contain at least one sequence."
            )
        for sequence_index, sequence in enumerate(msa):
            if sequence not in already_seen:
                already_seen.add(sequence)
                encoded_rows.append(
                    [
                        residue_constants.HHBLITS_AA_TO_ID[res]
                        for res in sequence
                    ]
                )
                deletion_rows.append(
                    deletion_matrices[msa_index][sequence_index]
                )

    # The residue count comes from the first sequence of the first MSA.
    num_res = len(msas[0][0])

    return {
        "deletion_matrix_int": np.array(deletion_rows, dtype=np.int32),
        "msa": np.array(encoded_rows, dtype=np.int32),
        # The number of kept sequences is repeated once per residue.
        "num_alignments": np.full(
            (num_res,), len(encoded_rows), dtype=np.int32
        ),
    }


class AlignmentRunner:
    """Runs alignment tools and saves the results"""

    def __init__(
        self,
        jackhmmer_binary_path: str,
        hhblits_binary_path: str,
        hhsearch_binary_path: str,
        uniref90_database_path: str,
        mgnify_database_path: str,
        bfd_database_path: Optional[str],
        uniclust30_database_path: Optional[str],
        small_bfd_database_path: Optional[str],
        pdb70_database_path: str,
        use_small_bfd: bool,
        no_cpus: int,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
    ):
        """Configure the tool runners used by :meth:`run`.

        Args:
            jackhmmer_binary_path: Path to the jackhmmer executable.
            hhblits_binary_path: Path to the hhblits executable.
            hhsearch_binary_path: Path to the hhsearch executable.
            uniref90_database_path: UniRef90 database for jackhmmer.
            mgnify_database_path: MGnify database for jackhmmer.
            bfd_database_path: BFD database for hhblits; required unless
                ``use_small_bfd`` is set.
            uniclust30_database_path: Uniclust30 database for hhblits;
                required unless ``use_small_bfd`` is set.
            small_bfd_database_path: Small BFD database for jackhmmer;
                required when ``use_small_bfd`` is set.
            pdb70_database_path: PDB70 database for hhsearch.
            use_small_bfd: Search the small BFD with jackhmmer instead of
                BFD + Uniclust30 with hhblits.
            no_cpus: CPU count passed to the jackhmmer and hhblits runners.
            uniref_max_hits: Maximum sequences kept from the UniRef90 hits.
            mgnify_max_hits: Maximum sequences kept from the MGnify hits.

        Raises:
            ValueError: If a database path needed by the selected BFD mode
                is None.
        """
        # Fail fast with a clear message instead of letting a None path
        # surface later inside one of the tool wrappers.
        if use_small_bfd:
            if small_bfd_database_path is None:
                raise ValueError(
                    "small_bfd_database_path is required when use_small_bfd "
                    "is True."
                )
        elif bfd_database_path is None or uniclust30_database_path is None:
            raise ValueError(
                "bfd_database_path and uniclust30_database_path are required "
                "when use_small_bfd is False."
            )

        self._use_small_bfd = use_small_bfd
        self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniref90_database_path,
            n_cpu=no_cpus,
        )

        if use_small_bfd:
            self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=small_bfd_database_path,
                n_cpu=no_cpus,
            )
        else:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path],
                n_cpu=no_cpus,
            )

        self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=mgnify_database_path,
            n_cpu=no_cpus,
        )

        self.hhsearch_pdb70_runner = hhsearch.HHSearch(
            binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
        )
        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits

    @staticmethod
    def _write_output(output_dir: str, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` inside ``output_dir``."""
        with open(os.path.join(output_dir, filename), "w") as f:
            f.write(contents)

    def run(
        self,
        fasta_path: str,
        output_dir: str,
    ):
        """Runs alignment tools on a sequence

        Writes ``uniref90_hits.a3m``, ``mgnify_hits.a3m`` and
        ``pdb70_hits.hhr`` to ``output_dir``, followed by either
        ``small_bfd_hits.sto`` or ``bfd_uniclust_hits.a3m`` depending on the
        BFD mode chosen at construction time.

        Args:
            fasta_path: Path of the FASTA file holding the query sequence.
            output_dir: Directory that receives the alignment files.
        """
        jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
            fasta_path
        )[0]
        uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_uniref90_result["sto"],
            max_sequences=self.uniref_max_hits,
        )
        self._write_output(
            output_dir, "uniref90_hits.a3m", uniref90_msa_as_a3m
        )

        jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
            fasta_path
        )[0]
        mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_mgnify_result["sto"], max_sequences=self.mgnify_max_hits
        )
        self._write_output(output_dir, "mgnify_hits.a3m", mgnify_msa_as_a3m)

        # The template search is seeded with the UniRef90 alignment.
        hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
        self._write_output(output_dir, "pdb70_hits.hhr", hhsearch_result)

        if self._use_small_bfd:
            jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
                fasta_path
            )[0]
            self._write_output(
                output_dir,
                "small_bfd_hits.sto",
                jackhmmer_small_bfd_result["sto"],
            )
        else:
            hhblits_bfd_uniclust_result = (
                self.hhblits_bfd_uniclust_runner.query(fasta_path)
            )
            # No None check here: output_dir has already been used
            # unconditionally for the files written above.
            self._write_output(
                output_dir,
                "bfd_uniclust_hits.a3m",
                hhblits_bfd_uniclust_result["a3m"],
            )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
264
265
266
267


class DataPipeline:
    """Assembles input features."""

    def __init__(
        self,
        template_featurizer: templates.TemplateHitFeaturizer,
    ):
        """Store the featurizer used to turn template hits into features."""
        self.template_featurizer = template_featurizer

    @staticmethod
    def _aatype_to_str_sequence(aatype) -> str:
        """Convert an array of residue-type indices to a residue string."""
        # NOTE(review): assumes aatype indexes residue_constants.restypes_with_x,
        # the inverse of the restype_order_with_x mapping - confirm.
        return "".join(
            residue_constants.restypes_with_x[res_index]
            for res_index in aatype
        )

    def _parse_msa_data(
        self,
        alignment_dir: str,
    ) -> Mapping[str, Any]:
        """Parse every ``.a3m`` and ``.sto`` file in ``alignment_dir``.

        Returns:
            A dict mapping each file name to a dict with the keys ``msa``
            and ``deletion_matrix``. Files with other extensions are skipped.
        """
        msa_data = {}
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # order in which MSAs are merged reproducible.
        for filename in sorted(os.listdir(alignment_dir)):
            path = os.path.join(alignment_dir, filename)
            ext = os.path.splitext(filename)[-1]

            if ext == ".a3m":
                with open(path, "r") as fp:
                    msa, deletion_matrix = parsers.parse_a3m(fp.read())
            elif ext == ".sto":
                with open(path, "r") as fp:
                    msa, deletion_matrix, _ = parsers.parse_stockholm(
                        fp.read()
                    )
            else:
                continue

            msa_data[filename] = {
                "msa": msa,
                "deletion_matrix": deletion_matrix,
            }

        return msa_data

    def _parse_template_hits(
        self,
        alignment_dir: str,
    ) -> Mapping[str, Any]:
        """Parse every ``.hhr`` file in ``alignment_dir``.

        Returns:
            A dict mapping each ``.hhr`` file name to its parsed hits.
        """
        all_hits = {}
        for filename in sorted(os.listdir(alignment_dir)):
            if os.path.splitext(filename)[-1] != ".hhr":
                continue
            with open(os.path.join(alignment_dir, filename), "r") as fp:
                all_hits[filename] = parsers.parse_hhr(fp.read())

        return all_hits

    def _process_msa_feats(
        self,
        alignment_dir: str,
    ) -> Mapping[str, Any]:
        """Build MSA features from all alignments in ``alignment_dir``.

        Raises:
            ValueError: If the directory holds no ``.a3m`` or ``.sto`` file.
        """
        msa_data = self._parse_msa_data(alignment_dir)
        if not msa_data:
            # Unpacking an empty zip() below would otherwise raise an
            # uninformative "not enough values to unpack" ValueError.
            raise ValueError(
                f"No .a3m or .sto alignment files found in {alignment_dir}."
            )

        msas, deletion_matrices = zip(*[
            (v["msa"], v["deletion_matrix"]) for v in msa_data.values()
        ])
        return make_msa_features(
            msas=msas,
            deletion_matrices=deletion_matrices,
        )

    def _get_template_features(
        self,
        input_sequence: str,
        alignment_dir: str,
        query_release_date=None,
    ) -> FeatureDict:
        """Build template features from the ``.hhr`` hits in a directory.

        Returns the empty-template placeholder when there are no hits or
        when the featurizer produces no templates.
        """
        hits = self._parse_template_hits(alignment_dir)
        # Flatten the per-file hit lists in one pass.
        hits_cat = [hit for file_hits in hits.values() for hit in file_hits]
        if len(hits_cat) == 0:
            return empty_template_feats(len(input_sequence))

        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=query_release_date,
            hits=hits_cat,
        )
        template_features = templates_result.features

        # The template featurizer doesn't format empty template features
        # properly. This is a quick fix.
        if template_features["template_aatype"].shape[0] == 0:
            template_features = empty_template_feats(len(input_sequence))

        return template_features

    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
        """Assembles features for a single sequence in a FASTA file

        Args:
            fasta_path: FASTA file containing exactly one sequence.
            alignment_dir: Directory with the precomputed alignments and
                template hits for that sequence.

        Raises:
            ValueError: If the FASTA file does not contain exactly one
                sequence.
        """
        with open(fasta_path) as f:
            fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f"Expected exactly one input sequence in {fasta_path}, "
                f"found {len(input_seqs)}."
            )
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        template_features = self._get_template_features(
            input_sequence, alignment_dir
        )

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        msa_features = self._process_msa_feats(alignment_dir)

        return {
            **sequence_features,
            **msa_features,
            **template_features
        }

    def process_mmcif(
        self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
    ) -> FeatureDict:
        """
            Assembles features for a specific chain in an mmCIF object.

            If chain_id is None, the first chain of the structure is used.
            A ValueError is raised if the structure contains no chains.
        """
        if chain_id is None:
            chains = mmcif.structure.get_chains()
            chain = next(chains, None)
            if chain is None:
                raise ValueError("No chains in mmCIF file")
            chain_id = chain.id

        mmcif_feats = make_mmcif_features(mmcif, chain_id)

        input_sequence = mmcif.chain_to_seqres[chain_id]
        # The structure's own release date is handed to the featurizer.
        template_features = self._get_template_features(
            input_sequence,
            alignment_dir,
            query_release_date=to_date(mmcif.header["release_date"]),
        )

        msa_features = self._process_msa_feats(alignment_dir)

        return {**mmcif_feats, **template_features, **msa_features}

    def process_pdb(
        self,
        pdb_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
        """
            Assembles features for a protein in a PDB file.

            The file name without its extension is used as the description.
        """
        with open(pdb_path, 'r') as f:
            # Parse the file's contents, not the path string itself.
            pdb_str = f.read()

        protein_object = protein.from_pdb_string(pdb_str)
        # Template search needs residue letters, not residue-type indices.
        input_sequence = self._aatype_to_str_sequence(protein_object.aatype)

        description = os.path.splitext(os.path.basename(pdb_path))[0]
        pdb_feats = make_pdb_features(protein_object, description)

        template_features = self._get_template_features(
            input_sequence, alignment_dir
        )

        msa_features = self._process_msa_feats(alignment_dir)

        return {**pdb_feats, **template_features, **msa_features}