# test_multimer_datamodule.py
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
import shutil
import pickle
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import unittest
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldMultimerDataModule,OpenFoldDataModule
from openfold.model.model import AlphaFold
from openfold.utils.loss import AlphaFoldMultimerLoss
from tests.config import consts
import logging
logger = logging.getLogger(__name__)
import os

class TestMultimerDataModule(unittest.TestCase):
    """End-to-end smoke test for ``OpenFoldMultimerDataModule``.

    Builds the multimer training dataset from the fixtures under
    ``tests/test_data``, takes one sample, and runs ``AlphaFold`` followed by
    ``AlphaFoldMultimerLoss`` on it. There are no explicit assertions: the
    test passes as long as none of these steps raise.

    Fixture paths are resolved against the current working directory, so the
    test is expected to be run from the repository root.
    """

    # Environment variable that overrides the template mmCIF directory.
    TEMPLATE_MMCIF_DIR_ENV = "OPENFOLD_TEMPLATE_MMCIF_DIR"
    # Default template mmCIF directory. This absolute path looks
    # site-specific, so it will be missing on most machines.
    DEFAULT_TEMPLATE_MMCIF_DIR = (
        "/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/"
    )

    def setUp(self):
        """Build the multimer data module, the model and the multimer loss.

        The data module uses the data section of the ``model_1_multimer_v3``
        preset (for now). The model and loss are built from ``consts.model``
        with a reduced Evoformer stack.

        The test is skipped, with an explicit reason, when the template mmCIF
        directory does not exist. Otherwise it would error out later and less
        clearly on machines without that directory.
        """
        template_mmcif_dir = os.environ.get(
            self.TEMPLATE_MMCIF_DIR_ENV, self.DEFAULT_TEMPLATE_MMCIF_DIR
        )
        if not os.path.isdir(template_mmcif_dir):
            self.skipTest(
                f"template mmCIF directory not found: {template_mmcif_dir!r} "
                f"(set {self.TEMPLATE_MMCIF_DIR_ENV} to point at it)"
            )

        # Fixture root, resolved against the CWD as in the original paths.
        test_data_dir = os.path.join(os.getcwd(), "tests", "test_data")

        self.config = model_config(
            "model_1_multimer_v3",
            train=True,
            low_prec=True,
        )
        self.data_module = OpenFoldMultimerDataModule(
            config=self.config.data,
            batch_seed=42,
            train_epoch_len=100,
            template_mmcif_dir=template_mmcif_dir,
            template_release_dates_cache_path=os.path.join(
                test_data_dir, "mmcif_cache.json"
            ),
            max_template_date="2500-01-01",
            train_data_dir=os.path.join(test_data_dir, "mmcifs"),
            # The trailing "" keeps the trailing separator the original
            # path string had.
            train_alignment_dir=os.path.join(test_data_dir, "alignments", ""),
            # shutil.which returns None when kalign is not on PATH.
            kalign_binary_path=shutil.which("kalign"),
            train_mmcif_data_cache_path=os.path.join(
                test_data_dir, "train_mmcifs_cache.json"
            ),
            train_chain_data_cache_path=os.path.join(
                test_data_dir, "train_chain_data_cache.json"
            ),
        )

        # Model/loss config, separate from the data config above.
        self.c = model_config(consts.model, train=True)
        # NOTE(review): the multimer loss apparently needs
        # masked_msa.num_classes overridden to 22 here; the reason is not
        # documented in this file. Confirm against the multimer loss config.
        self.c.loss.masked_msa.num_classes = 22
        # Keep the model small; there is no need to go overboard here.
        self.c.model.evoformer_stack.no_blocks = 4
        # blocks_per_ckpt=None so that DeepSpeed does not have to be set up
        # for this test.
        self.c.model.evoformer_stack.blocks_per_ckpt = None
        self.model = AlphaFold(self.c)
        self.multimer_loss = AlphaFoldMultimerLoss(self.c.loss)

    def testPrepareData(self):
        """Prepare and set up the data module, then run model and loss.

        Fetches sample ``1`` of the training dataset, adds a leading size-1
        dimension to every feature tensor, and runs the model and the
        multimer loss on it without gradients.
        """
        self.data_module.prepare_data()
        self.data_module.setup()
        train_dataset = self.data_module.train_dataset
        all_chain_features, ground_truth = train_dataset[1]

        def add_batch_size_dimension(t):
            # Prepend a size-1 leading dimension to one feature tensor.
            return t.unsqueeze(0)

        all_chain_features = tensor_tree_map(
            add_batch_size_dimension, all_chain_features
        )
        with torch.no_grad():
            out = self.model(all_chain_features)
            self.multimer_loss(out, (all_chain_features, ground_truth))