Commit 651949b2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Move config script, reformat data_transforms

parent f4150fa1
......@@ -5,7 +5,7 @@ import numpy as np
import torch
from operator import add
from config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants
MSA_FEATURE_NAMES = [
......@@ -29,7 +29,9 @@ def make_seq_mask(protein):
return protein
def make_template_mask(protein):
    """Mark every template as valid.

    Args:
        protein: Feature dict; must contain 'template_aatype', whose leading
            dimension is the number of templates.

    Returns:
        The same dict with 'template_mask' set to a float32 vector of ones,
        one entry per template.
    """
    # NOTE(review): the diff view duplicated this assignment (old + new
    # formatting); only one assignment belongs in the real source.
    protein['template_mask'] = torch.ones(
        protein['template_aatype'].shape[0], dtype=torch.float32
    )
    return protein
def curry1(f):
......@@ -42,7 +44,9 @@ def curry1(f):
@curry1
def add_distillation_flag(protein, distillation):
    """Record whether this example comes from the distillation set.

    Args:
        protein: Feature dict to update in place.
        distillation: Truthy/falsy flag; stored as a float32 scalar tensor
            under 'is_distillation'.

    Returns:
        The same protein dict.
    """
    # NOTE(review): the diff view duplicated this assignment (old + new
    # formatting); only one assignment belongs in the real source.
    protein['is_distillation'] = torch.tensor(
        float(distillation), dtype=torch.float32
    )
    return protein
def make_all_atom_aatype(protein):
......@@ -55,14 +59,20 @@ def fix_templates_aatype(protein):
protein['template_aatype'] = torch.argmax(protein['template_aatype'], dim=-1)
# Map hhsearch-aatype to our aatype.
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(new_order_list, dtype=torch.int32).expand(num_templates, -1)
protein['template_aatype'] = torch.gather(new_order, 1, index=protein['template_aatype'])
new_order = torch.tensor(
new_order_list, dtype=torch.int32
).expand(num_templates, -1)
protein['template_aatype'] = torch.gather(
new_order, 1, index=protein['template_aatype']
)
return protein
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor([new_order_list]*protein['msa'].shape[1], dtype=protein['msa'].dtype).transpose(0,1)
new_order = torch.tensor(
[new_order_list]*protein['msa'].shape[1], dtype=protein['msa'].dtype
).transpose(0,1)
protein['msa'] = torch.gather(new_order, 0, protein['msa'])
perm_matrix = np.zeros((22, 22), dtype=np.float32)
......@@ -94,7 +104,9 @@ def squeeze_features(protein):
return protein
def make_protein_crop_to_size_seed(protein):
    """Attach a pair of random int32 seeds used for random cropping.

    Args:
        protein: Feature dict to update in place.

    Returns:
        The same dict with 'random_crop_to_size_seed' set to a shape-(2,)
        int32 tensor of uniformly random values.
    """
    # Bug fix: the original passed torch.int32 (a dtype object) as the
    # Uniform bounds and `.sample((2))` — `(2)` is just the int 2, not a
    # shape tuple — so the call raised at runtime.  Draw two seeds uniformly
    # over the full int32 range instead, mirroring the AlphaFold TF code
    # `tf.random.uniform([2], tf.int32.min, tf.int32.max, tf.int32)`.
    int32_info = torch.iinfo(torch.int32)
    protein['random_crop_to_size_seed'] = torch.randint(
        low=int32_info.min,
        high=int32_info.max,
        size=(2,),
        dtype=torch.int32,
    )
    return protein
@curry1
......@@ -110,8 +122,10 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
torch.rand(protein['aatype'].shape) < replace_proportion
)
protein['aatype'] = torch.where(aatype_mask, torch.ones_like(protein['aatype']) * x_idx,
protein['aatype'])
protein['aatype'] = torch.where(
aatype_mask, torch.ones_like(protein['aatype']) * x_idx,
protein['aatype']
)
return protein
@curry1
......@@ -151,7 +165,11 @@ def delete_extra_msa(protein):
@curry1
def block_delete_msa(protein, config):
num_seq = protein['msa'].shape[0]
block_num_seq = torch.floor(torch.tensor(num_seq, dtype=torch.float32) * config.msa_fraction_per_block).to(torch.int32)
block_num_seq = torch.floor(
torch.tensor(
num_seq, dtype=torch.float32
) * config.msa_fraction_per_block
).to(torch.int32)
if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(0, config.num_blocks+1).sample()
......@@ -195,9 +213,12 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
agreement = torch.matmul(torch.reshape(extra_one_hot, [extra_num_seq, num_res*23]),
torch.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).transpose(0, 1),
)
agreement = torch.matmul(
torch.reshape(extra_one_hot, [extra_num_seq, num_res*23]),
torch.reshape(
sample_one_hot * weights, [num_seq, num_res * 23]
).transpose(0, 1),
)
# Assign each sequence in the extra sequences to the closest MSA sample
protein['extra_cluster_assignment'] = torch.argmax(agreement, dim=1).to(torch.int64)
......@@ -213,14 +234,18 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
"""
assert all([i in data.shape for i in segment_ids.shape]), "segment_ids.shape should be a prefix of data.shape"
# segment_ids.shape should be a prefix of data.shape
assert all([i in data.shape for i in segment_ids.shape])
# segment_ids is a 1-D tensor repeat it to have the same shape as data
if len(segment_ids.shape) == 1:
s = torch.prod(torch.tensor(data.shape[1:])).long()
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:])
segment_ids = segment_ids.repeat_interleave(s).view(
segment_ids.shape[0], *data.shape[1:]
)
assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"
# data.shape and segment_ids.shape should be equal
assert data.shape == segment_ids.shape
shape = [num_segments] + list(data.shape[1:])
tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float())
......@@ -232,7 +257,9 @@ def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq = protein['msa'].shape[0]
def csum(x):
return unsorted_segment_sum(x, protein['extra_cluster_assignment'], num_seq)
return unsorted_segment_sum(
x, protein['extra_cluster_assignment'], num_seq
)
mask = protein['extra_msa_mask']
mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center
......@@ -292,7 +319,9 @@ def add_constant_field(protein, key, value):
def shaped_categorical(probs, epsilon=1e-10):
    """Sample class indices from a categorical over the last axis of probs.

    Args:
        probs: Tensor of (possibly unnormalized) probabilities; the last
            dimension indexes the classes.
        epsilon: Small constant added before sampling so all-zero rows do
            not produce an invalid distribution.

    Returns:
        Integer tensor of sampled class indices with shape probs.shape[:-1].
    """
    # NOTE(review): the diff view duplicated the Categorical construction
    # (old + new formatting); only one belongs in the real source.
    ds = probs.shape
    num_classes = ds[-1]
    distribution = torch.distributions.categorical.Categorical(
        torch.reshape(probs + epsilon, [-1, num_classes])
    )
    counts = distribution.sample()
    return torch.reshape(counts, ds[:-1])
......@@ -323,7 +352,9 @@ def make_masked_msa(protein, config, replace_fraction):
pad_shapes[1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0.
categorical_probs = torch.nn.functional.pad(categorical_probs, pad_shapes, value=mask_prob)
categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob
)
sh = protein['msa'].shape
mask_position = torch.rand(sh) < replace_fraction
......@@ -339,7 +370,14 @@ def make_masked_msa(protein, config, replace_fraction):
return protein
@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num_res=0, num_templates=0):
def make_fixed_size(
protein,
shape_schema,
msa_cluster_size,
extra_msa_size,
num_res=0,
num_templates=0
):
"""Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map = {
......@@ -355,9 +393,13 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
continue
shape = list(v.shape)
schema = shape_schema[k]
msg = "Rank mismatch between shape and shape schema for"
assert len(shape) == len(schema), (
f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}')
pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)]
f'{msg} {k}: {shape} vs {schema}'
)
pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
]
padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
padding.reverse()
......@@ -371,8 +413,11 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
@curry1
def make_msa_feat(protein):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for compatibility with domain datasets.
has_break = torch.clip(protein['between_segment_residues'].to(torch.float32), 0, 1)
# Whether there is a domain break. Always zero for chains, but keeping for
# compatibility with domain datasets.
has_break = torch.clip(
protein['between_segment_residues'].to(torch.float32), 0, 1
)
aatype_1hot = make_one_hot(protein['aatype'], 21)
target_feat = [
......@@ -391,14 +436,20 @@ def make_msa_feat(protein):
]
if 'cluster_profile' in protein:
deletion_mean_value = (torch.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
deletion_mean_value = (
torch.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi)
)
msa_feat.extend([protein['cluster_profile'],
torch.unsqueeze(deletion_mean_value, dim=-1),
])
if 'extra_deletion_matrix' in protein:
protein['extra_has_deletion'] = torch.clip(protein['extra_deletion_matrix'], 0., 1.)
protein['extra_deletion_value'] = torch.atan(protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)
protein['extra_has_deletion'] = torch.clip(
protein['extra_deletion_matrix'], 0., 1.
)
protein['extra_deletion_value'] = torch.atan(
protein['extra_deletion_matrix'] / 3.
) * (2. / np.pi)
protein['msa_feat'] = torch.cat(msa_feat, dim=-1)
protein['target_feat'] = torch.cat(target_feat, dim=-1)
......@@ -422,35 +473,51 @@ def make_atom14_masks(protein):
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[residue_constants.restype_1to3[rt]]
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]
]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0) for name in atom_names
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
# Since all 14 atoms are not present in every residue, use this mask to tell which atom is there in this residue
# Since all 14 atoms are not present in every residue, use this mask to
# tell which atom is there in this residue
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_to_atom37 = torch.tensor(restype_atom14_to_atom37, dtype=torch.int32)
restype_atom37_to_atom14 = torch.tensor(restype_atom37_to_atom14, dtype=torch.int32)
restype_atom14_mask = torch.tensor(restype_atom14_mask, dtype=torch.float32)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37, dtype=torch.int32
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14, dtype=torch.int32
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask, dtype=torch.float32
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = torch.index_select(restype_atom14_to_atom37, 0, protein['aatype'])
residx_atom14_mask = torch.index_select(restype_atom14_mask, 0, protein['aatype'])
residx_atom14_to_atom37 = torch.index_select(
restype_atom14_to_atom37, 0, protein['aatype']
)
residx_atom14_mask = torch.index_select(
restype_atom14_mask, 0, protein['aatype']
)
protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
# create the gather indices for mapping back
residx_atom37_to_atom14 = torch.index_select(restype_atom37_to_atom14, 0, protein['aatype'])
residx_atom37_to_atom14 = torch.index_select(
restype_atom37_to_atom14, 0, protein['aatype']
)
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
# create the corresponding mask
......@@ -462,7 +529,9 @@ def make_atom14_masks(protein):
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = torch.index_select(restype_atom37_mask, 0, protein['aatype'])
residx_atom37_mask = torch.index_select(
restype_atom37_mask, 0, protein['aatype']
)
protein['atom37_atom_exists'] = residx_atom37_mask
return protein
\ No newline at end of file
return protein
......@@ -31,7 +31,7 @@ import time
import numpy as np
import torch
from config import model_config
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
......
......@@ -6,7 +6,7 @@ import unittest
import numpy as np
from config import model_config
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts
......
......@@ -16,7 +16,7 @@ import torch
import numpy as np
import unittest
from config import model_config
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
......
......@@ -17,7 +17,7 @@ import torch
import torch.nn as nn
import numpy as np
import unittest
from config import *
from openfold.config import model_config
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment