# test_multimer_datamodule.py
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import torch
import unittest
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.utils.loss import AlphaFoldMultimerLoss
from tests.config import consts
import logging
logger = logging.getLogger(__name__)


@unittest.skipIf(
    not consts.is_multimer or consts.template_mmcif_dir is None,
    "Template mmcif dir required.",
)
class TestMultimerDataModule(unittest.TestCase):
    """Smoke test for OpenFoldMultimerDataModule.

    Builds the data module from the fixtures under ``tests/test_data``,
    pulls one training sample from it and runs that sample through an
    ``AlphaFold`` model and ``AlphaFoldMultimerLoss``.

    The whole class is skipped unless ``consts.is_multimer`` is truthy and
    ``consts.template_mmcif_dir`` is set.
    """

    def setUp(self):
        """
        Set up the data module, the model and the loss.

        The preset is taken from ``consts.model`` (original note: use
        model_1_multimer_v3 for now).
        """
        self.config = model_config(
            consts.model,
            train=True,
            low_prec=True,
        )

        # Fixture paths are resolved against the current working directory,
        # so the test has to be launched from the directory that contains
        # ``tests/``.
        cwd = os.getcwd()
        self.data_module = OpenFoldMultimerDataModule(
            config=self.config.data,
            batch_seed=42,
            train_epoch_len=100,
            template_mmcif_dir=consts.template_mmcif_dir,
            template_release_dates_cache_path=os.path.join(
                cwd, "tests/test_data/mmcif_cache.json"
            ),
            max_template_date="2500-01-01",
            train_data_dir=os.path.join(cwd, "tests/test_data/mmcifs"),
            train_alignment_dir=os.path.join(cwd, "tests/test_data/alignments/"),
            # shutil.which returns None when kalign is not on PATH.
            kalign_binary_path=shutil.which('kalign'),
            train_mmcif_data_cache_path=os.path.join(
                cwd, "tests/test_data/train_mmcifs_cache.json"
            ),
            train_chain_data_cache_path=os.path.join(
                cwd, "tests/test_data/train_chain_data_cache.json"
            ),
        )

        # setup model
        # NOTE: this is a second config object, built without low_prec;
        # self.config above is only used for its ``.data`` section.
        self.c = model_config(consts.model, train=True)
        # somehow need overwrite this part in multimer loss config
        self.c.loss.masked_msa.num_classes = 22
        # no need to go overboard here
        self.c.model.evoformer_stack.no_blocks = 4
        # don't want to set up deepspeed for this test
        self.c.model.evoformer_stack.blocks_per_ckpt = None
        self.model = AlphaFold(self.c)
        self.multimer_loss = AlphaFoldMultimerLoss(self.c.loss)

    def testPrepareData(self):
        """Run one training sample through the model and the multimer loss.

        There are no explicit assertions: the test passes when data
        preparation, the forward pass and the loss call complete without
        raising.
        """
        self.data_module.prepare_data()
        self.data_module.setup()
        train_dataset = self.data_module.train_dataset
        all_chain_features, ground_truth = train_dataset[1]

        def add_batch_size_dimension(t):
            """Prepend a dimension of size 1 to tensor ``t``."""
            return t.unsqueeze(0)

        # Only the features get the extra leading dimension; ground_truth is
        # handed to the loss exactly as the dataset returned it.
        all_chain_features = tensor_tree_map(
            add_batch_size_dimension, all_chain_features
        )
        with torch.no_grad():
            out = self.model(all_chain_features)
            self.multimer_loss(out, (all_chain_features, ground_truth))