# config.py (last committed by Gustaf Ahdritz)
import copy
import ml_collections as mlc


5
6
7
8
9
10
11
12
13
def set_inf(c, inf):
    """Overwrite, in place, every entry keyed "inf" in ``c`` with ``inf``.

    Nested ``mlc.ConfigDict`` values are walked as well. A value that is itself
    a ConfigDict is descended into rather than replaced, even if its key is
    "inf". Returns None.
    """
    # Explicit work stack instead of recursion; each node is updated
    # independently, so traversal order does not affect the final result.
    pending = [c]
    while pending:
        node = pending.pop()
        for key, value in node.items():
            if isinstance(value, mlc.ConfigDict):
                pending.append(value)
            elif key == "inf":
                node[key] = inf


def model_config(name, train=False, low_prec=False):
    """Return a deep copy of the module-level ``config`` specialized for ``name``.

    Args:
        name: One of "model_1" ... "model_5", optionally suffixed with "_ptm".
            For models 3-5, ``model.template.enabled`` is set to False. For
            "_ptm" variants, ``model.heads.tm.enabled`` is set to True and
            ``loss.tm.weight`` to 0.1.
        train: If True, set ``globals.blocks_per_ckpt`` to 1 and
            ``globals.chunk_size`` to None.
        low_prec: If True, set ``globals.eps`` to 1e-4 and every "inf" entry
            (at any nesting depth) to 1e4.

    Returns:
        A deep copy of ``config`` with the above overrides applied.

    Raises:
        ValueError: If ``name`` is not a recognized model name.
    """
    no_template_models = {"model_3", "model_4", "model_5"}
    known_models = {"model_1", "model_2"} | no_template_models

    ptm_suffix = "_ptm"
    is_ptm = isinstance(name, str) and name.endswith(ptm_suffix)
    base_name = name[:-len(ptm_suffix)] if is_ptm else name
    # Validate before deep-copying the (large) default config. The isinstance
    # check short-circuits so non-string names still raise ValueError.
    if not isinstance(name, str) or base_name not in known_models:
        raise ValueError(f"Invalid model name: {name!r}")

    c = copy.deepcopy(config)

    if base_name in no_template_models:
        c.model.template.enabled = False
    if is_ptm:
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1

    if train:
        # These "globals" entries are FieldReferences shared across the
        # config, so one assignment updates every field that references them.
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None

    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    return c
# Module-level FieldReferences. They are placed in the "globals" section of
# `config` below (and referenced throughout it), so changing one value there
# changes every entry that shares the reference.

# Channel widths, created in the order c_z, c_m, c_t, c_e, c_s.
c_z, c_m, c_t, c_e, c_s = (
    mlc.FieldReference(width, field_type=int)
    for width in (128, 256, 64, 64, 384)
)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
68

69
70
71
72
73
# Symbolic dimension names used in the eval feature-shape table of `config`.
NUM_RES, NUM_MSA_SEQ, NUM_EXTRA_SEQ, NUM_TEMPLATES = (
    'num residues placeholder',
    'msa placeholder',
    'extra msa placeholder',
    'num templates placeholder',
)

# Default configuration. model_config() returns specialized deep copies of it.
# Entries set to one of the module-level FieldReferences (see "globals") share
# that reference and therefore change together.
config = mlc.ConfigDict({
    'data': {
        'common': {
            'masked_msa': {
                'profile_prob': 0.1,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'max_extra_msa': 1024,
            'msa_cluster_features': True,
            'num_recycle': 3,
            'reduce_msa_clusters_by_max_templates': False,
            'resample_msa_in_recycling': True,
            'template_features': [
                'template_all_atom_positions', 'template_sum_probs',
                'template_aatype', 'template_all_atom_masks',
                # 'template_domain_names'
            ],
            'unsupervised_features': [
                'aatype', 'residue_index', 'msa',  # 'sequence', #'domain_name',
                'num_alignments', 'seq_length', 'between_segment_residues',
                'deletion_matrix'
            ],
            'use_templates': True,
        },
        'eval': {
            # Per-feature shapes; the NUM_* strings are symbolic dimensions and
            # None marks a dimension without a placeholder name.
            'feat': {
                'aatype': [NUM_RES],
                'all_atom_mask': [NUM_RES, None],
                'all_atom_positions': [NUM_RES, None, None],
                'alt_chi_angles': [NUM_RES, None],
                'atom14_alt_gt_exists': [NUM_RES, None],
                'atom14_alt_gt_positions': [NUM_RES, None, None],
                'atom14_atom_exists': [NUM_RES, None],
                'atom14_atom_is_ambiguous': [NUM_RES, None],
                'atom14_gt_exists': [NUM_RES, None],
                'atom14_gt_positions': [NUM_RES, None, None],
                'atom37_atom_exists': [NUM_RES, None],
                'backbone_affine_mask': [NUM_RES],
                'backbone_affine_tensor': [NUM_RES, None],
                'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                'chi_angles': [NUM_RES, None],
                'chi_mask': [NUM_RES, None],
                'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_row_mask': [NUM_EXTRA_SEQ],
                'is_distillation': [],
                'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
                'msa_mask': [NUM_MSA_SEQ, NUM_RES],
                'msa_row_mask': [NUM_MSA_SEQ],
                'pseudo_beta': [NUM_RES, None],
                'pseudo_beta_mask': [NUM_RES],
                'random_crop_to_size_seed': [None],
                'residue_index': [NUM_RES],
                'residx_atom14_to_atom37': [NUM_RES, None],
                'residx_atom37_to_atom14': [NUM_RES, None],
                'resolution': [],
                'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
                'rigidgroups_group_exists': [NUM_RES, None],
                'rigidgroups_group_is_ambiguous': [NUM_RES, None],
                'rigidgroups_gt_exists': [NUM_RES, None],
                'rigidgroups_gt_frames': [NUM_RES, None, None],
                'seq_length': [],
                'seq_mask': [NUM_RES],
                'target_feat': [NUM_RES, None],
                'template_aatype': [NUM_TEMPLATES, NUM_RES],
                'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
                'template_all_atom_positions': [
                    NUM_TEMPLATES, NUM_RES, None, None],
                'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
                'template_backbone_affine_tensor': [
                    NUM_TEMPLATES, NUM_RES, None],
                'template_mask': [NUM_TEMPLATES],
                'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
                'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
                'template_sum_probs': [NUM_TEMPLATES, None],
                'true_msa': [NUM_MSA_SEQ, NUM_RES]
            },
            'fixed_size': True,
            'subsample_templates': False,  # We want top templates.
            'masked_msa_replace_fraction': 0.15,
            'max_msa_clusters': 512,
            'max_templates': 4,
            'num_ensemble': 1,
        }
    },
    # Recurring FieldReferences that can be changed globally here
    "globals": {
        "blocks_per_ckpt": blocks_per_ckpt,
        "chunk_size": chunk_size,
        "c_z": c_z,
        "c_m": c_m,
        "c_t": c_t,
        "c_e": c_e,
        "c_s": c_s,
        "eps": eps,
    },
    "model": {
        "no_cycles": 4,
        "_mask_trans": False,
        "input_embedder": {
            "tf_dim": 22,
            "msa_dim": 49,
            "c_z": c_z,
            "c_m": c_m,
            "relpos_k": 32,
        },
        "recycling_embedder": {
            "c_z": c_z,
            "c_m": c_m,
            "min_bin": 3.25,
            "max_bin": 20.75,
            "no_bins": 15,
            "inf": 1e8,
        },
        "template": {
            "distogram": {
                "min_bin": 3.25,
                "max_bin": 50.75,
                "no_bins": 39,
            },
            "template_angle_embedder": {
                # DISCREPANCY: c_in is supposed to be 51.
                "c_in": 57,
                "c_out": c_m,
            },
            "template_pair_embedder": {
                "c_in": 88,
                "c_out": c_t,
            },
            "template_pair_stack": {
                "c_t": c_t,
                # DISCREPANCY: c_hidden_tri_att here is given in the supplement
                # as 64. In the code, it's 16.
                "c_hidden_tri_att": 16,
                "c_hidden_tri_mul": 64,
                "no_blocks": 2,
                "no_heads": 4,
                "pair_transition_n": 2,
                "dropout_rate": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "chunk_size": chunk_size,
                "inf": 1e5,  # previously 1e9
            },
            "template_pointwise_attention": {
                "c_t": c_t,
                "c_z": c_z,
                # DISCREPANCY: c_hidden here is given in the supplement as 64.
                # It's actually 16.
                "c_hidden": 16,
                "no_heads": 4,
                "chunk_size": chunk_size,
                "inf": 1e5,  # previously 1e9
            },
            "inf": 1e5,  # previously 1e9
            "eps": eps,  # previously 1e-6
            "enabled": True,
            "embed_angles": True,
        },
        "extra_msa": {
            "extra_msa_embedder": {
                "c_in": 25,
                "c_out": c_e,
            },
            "extra_msa_stack": {
                "c_m": c_e,
                "c_z": c_z,
                "c_hidden_msa_att": 8,
                "c_hidden_opm": 32,
                "c_hidden_mul": 128,
                "c_hidden_pair_att": 32,
                "no_heads_msa": 8,
                "no_heads_pair": 4,
                "no_blocks": 4,
                "transition_n": 4,
                "msa_dropout": 0.15,
                "pair_dropout": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "chunk_size": chunk_size,
                "inf": 1e5,  # previously 1e9
                "eps": eps,  # previously 1e-10
            },
            "enabled": True,
        },
        "evoformer_stack": {
            "c_m": c_m,
            "c_z": c_z,
            "c_hidden_msa_att": 32,
            "c_hidden_opm": 32,
            "c_hidden_mul": 128,
            "c_hidden_pair_att": 32,
            "c_s": c_s,
            "no_heads_msa": 8,
            "no_heads_pair": 4,
            "no_blocks": 48,
            "transition_n": 4,
            "msa_dropout": 0.15,
            "pair_dropout": 0.25,
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            "inf": 1e5,  # previously 1e9
            "eps": eps,  # previously 1e-10
        },
        "structure_module": {
            "c_s": c_s,
            "c_z": c_z,
            "c_ipa": 16,
            "c_resnet": 128,
            "no_heads_ipa": 12,
            "no_qk_points": 4,
            "no_v_points": 8,
            "dropout_rate": 0.1,
            "no_blocks": 8,
            "no_transition_layers": 1,
            "no_resnet_blocks": 2,
            "no_angles": 7,
            "trans_scale_factor": 10,
            "epsilon": eps,  # previously 1e-12
            "inf": 1e5,
        },
        "heads": {
            "lddt": {
                "no_bins": 50,
                "c_in": c_s,
                "c_hidden": 128,
            },
            "distogram": {
                "c_z": c_z,
                "no_bins": aux_distogram_bins,
            },
            "tm": {
                "c_z": c_z,
                "no_bins": aux_distogram_bins,
                "enabled": False,
            },
            "masked_msa": {
                "c_m": c_m,
                "c_out": 23,
            },
            "experimentally_resolved": {
                "c_s": c_s,
                "c_out": 37,
            },
        },
    },
    "relax": {
        "max_iterations": 0,  # no max
        "tolerance": 2.39,
        "stiffness": 10.0,
        "max_outer_iterations": 20,
        "exclude_residues": [],
    },
    "loss": {
        "distogram": {
            "min_bin": 2.3125,
            "max_bin": 21.6875,
            # Shares the reference used by model.heads.distogram.no_bins so the
            # head and its loss always agree on the bin count (default 64).
            "no_bins": aux_distogram_bins,
            "eps": eps,  # previously 1e-6
            "weight": 0.3,
        },
        "experimentally_resolved": {
            "eps": eps,  # previously 1e-8
            "min_resolution": 0.1,
            "max_resolution": 3.0,
            "weight": 0.,
        },
        "fape": {
            "backbone": {
                "clamp_distance": 10.,
                "loss_unit_distance": 10.,
                "weight": 0.5,
            },
            "sidechain": {
                "clamp_distance": 10.,
                "length_scale": 10.,
                "weight": 0.5,
            },
            "eps": 1e-4,
            "weight": 1.0,
        },
        "lddt": {
            "min_resolution": 0.1,
            "max_resolution": 3.0,
            "cutoff": 15.,
            "no_bins": 50,
            "eps": eps,  # previously 1e-10
            "weight": 0.01,
        },
        "masked_msa": {
            "eps": eps,  # previously 1e-8
            "weight": 2.0,
        },
        "supervised_chi": {
            "chi_weight": 0.5,
            "angle_norm_weight": 0.01,
            "eps": eps,  # previously 1e-6
            "weight": 1.0,
        },
        "violation": {
            "violation_tolerance_factor": 12.0,
            "clash_overlap_tolerance": 1.5,
            "eps": eps,  # previously 1e-6
            "weight": 0.,
        },
        "tm": {
            "max_bin": 31,
            # Shares the reference used by model.heads.tm.no_bins so the head
            # and its loss always agree on the bin count (default 64).
            "no_bins": aux_distogram_bins,
            "min_resolution": 0.1,
            "max_resolution": 3.0,
            "eps": eps,  # previously 1e-8
            "weight": 0.,
        },
        "eps": eps,
    },
})