test_multimer_datamodule.py 3.74 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
Geoffrey Yu's avatar
Geoffrey Yu committed
16
import shutil
17
18
import torch
import unittest
Geoffrey Yu's avatar
Geoffrey Yu committed
19
from openfold.utils.tensor_utils import tensor_tree_map
20
from openfold.config import model_config
21
from openfold.data.data_modules import OpenFoldMultimerDataModule
Geoffrey Yu's avatar
Geoffrey Yu committed
22
from openfold.model.model import AlphaFold
23
24
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
Geoffrey Yu's avatar
Geoffrey Yu committed
25
from tests.config import consts
26
27
28
import logging
logger = logging.getLogger(__name__)

29
30

@unittest.skipIf(
    not consts.is_multimer or consts.template_mmcif_dir is None,
    "Multimer model config and template mmcif dir required.",
)
class TestMultimerDataModule(unittest.TestCase):
    """Smoke test: multimer data module -> AlphaFold -> permutation align -> loss.

    Fixture paths are built from ``os.getcwd()``, so the test expects to be
    launched from the directory that contains ``tests/test_data``.
    """

    def setUp(self):
        """Build the multimer data module, a reduced-size model and its loss.

        Two configs are built from ``consts.model``:
          * ``self.config`` (``low_prec=True``): its ``data`` section drives
            the data module.
          * ``self.c``: used for the model and the loss, with a shrunken
            evoformer stack to keep the test cheap.
        """
        self.config = model_config(consts.model, train=True, low_prec=True)

        # NOTE(review): resolved against the current working directory, not
        # this file's location — run the test from the repository root.
        test_data_dir = os.path.join(os.getcwd(), "tests/test_data")

        self.data_module = OpenFoldMultimerDataModule(
            config=self.config.data,
            batch_seed=42,
            train_epoch_len=100,
            template_mmcif_dir=consts.template_mmcif_dir,
            template_release_dates_cache_path=os.path.join(
                test_data_dir, "mmcif_cache.json"
            ),
            max_template_date="2500-01-01",
            train_data_dir=os.path.join(test_data_dir, "mmcifs"),
            # Trailing slash kept to match the original path string exactly.
            train_alignment_dir=os.path.join(test_data_dir, "alignments/"),
            # None when kalign is not on PATH.
            kalign_binary_path=shutil.which('kalign'),
            train_mmcif_data_cache_path=os.path.join(
                test_data_dir, "train_mmcifs_cache.json"
            ),
            train_chain_data_cache_path=os.path.join(
                test_data_dir, "train_chain_data_cache.json"
            ),
        )

        # Model / loss config (full precision).
        self.c = model_config(consts.model, train=True)
        # Workaround from the original author: the multimer loss config needs
        # masked_msa.num_classes overridden to 22.
        self.c.loss.masked_msa.num_classes = 22
        self.c.model.evoformer_stack.no_blocks = 4  # no need to go overboard here
        # Don't want to set up DeepSpeed for this test.
        self.c.model.evoformer_stack.blocks_per_ckpt = None

        self.model = AlphaFold(self.c)
        self.loss = AlphaFoldLoss(self.c.loss)

    def testPrepareData(self):
        """Run one training example end to end and check the loss is finite."""
        self.data_module.prepare_data()
        self.data_module.setup()
        train_dataset = self.data_module.train_dataset
        all_chain_features = train_dataset[1]

        # Add a leading batch dimension of size 1 to every tensor in the tree.
        all_chain_features = tensor_tree_map(
            lambda t: t.unsqueeze(0), all_chain_features
        )

        with torch.no_grad():
            # Ground-truth features are needed for permutation alignment but
            # must not be fed to the model.
            ground_truth = all_chain_features.pop('gt_features', None)
            self.assertIsNotNone(
                ground_truth, "training example is missing 'gt_features'"
            )

            # Run the model
            out = self.model(all_chain_features)

            # Remove the recycling dimension
            all_chain_features = tensor_tree_map(
                lambda t: t[..., -1], all_chain_features
            )

            all_chain_features = multi_chain_permutation_align(
                out=out,
                features=all_chain_features,
                ground_truth=ground_truth,
            )

            loss = self.loss(out, all_chain_features)

        # The original test discarded the loss; at minimum it must be finite.
        self.assertTrue(torch.all(torch.isfinite(loss)))