Commit 651949b2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Move config script, reformat data_transforms

parent f4150fa1
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import torch import torch
from operator import add from operator import add
from config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants from openfold.np import residue_constants
MSA_FEATURE_NAMES = [ MSA_FEATURE_NAMES = [
...@@ -29,7 +29,9 @@ def make_seq_mask(protein): ...@@ -29,7 +29,9 @@ def make_seq_mask(protein):
return protein return protein
def make_template_mask(protein): def make_template_mask(protein):
protein['template_mask'] = torch.ones(protein['template_aatype'].shape[0], dtype=torch.float32) protein['template_mask'] = torch.ones(
protein['template_aatype'].shape[0], dtype=torch.float32
)
return protein return protein
def curry1(f): def curry1(f):
...@@ -42,7 +44,9 @@ def curry1(f): ...@@ -42,7 +44,9 @@ def curry1(f):
@curry1 @curry1
def add_distillation_flag(protein, distillation): def add_distillation_flag(protein, distillation):
protein['is_distillation'] = torch.tensor(float(distillation), dtype=torch.float32) protein['is_distillation'] = torch.tensor(
float(distillation), dtype=torch.float32
)
return protein return protein
def make_all_atom_aatype(protein): def make_all_atom_aatype(protein):
...@@ -55,14 +59,20 @@ def fix_templates_aatype(protein): ...@@ -55,14 +59,20 @@ def fix_templates_aatype(protein):
protein['template_aatype'] = torch.argmax(protein['template_aatype'], dim=-1) protein['template_aatype'] = torch.argmax(protein['template_aatype'], dim=-1)
# Map hhsearch-aatype to our aatype. # Map hhsearch-aatype to our aatype.
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(new_order_list, dtype=torch.int32).expand(num_templates, -1) new_order = torch.tensor(
protein['template_aatype'] = torch.gather(new_order, 1, index=protein['template_aatype']) new_order_list, dtype=torch.int32
).expand(num_templates, -1)
protein['template_aatype'] = torch.gather(
new_order, 1, index=protein['template_aatype']
)
return protein return protein
def correct_msa_restypes(protein): def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as residue_constants.""" """Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor([new_order_list]*protein['msa'].shape[1], dtype=protein['msa'].dtype).transpose(0,1) new_order = torch.tensor(
[new_order_list]*protein['msa'].shape[1], dtype=protein['msa'].dtype
).transpose(0,1)
protein['msa'] = torch.gather(new_order, 0, protein['msa']) protein['msa'] = torch.gather(new_order, 0, protein['msa'])
perm_matrix = np.zeros((22, 22), dtype=np.float32) perm_matrix = np.zeros((22, 22), dtype=np.float32)
...@@ -94,7 +104,9 @@ def squeeze_features(protein): ...@@ -94,7 +104,9 @@ def squeeze_features(protein):
return protein return protein
def make_protein_crop_to_size_seed(protein): def make_protein_crop_to_size_seed(protein):
protein['random_crop_to_size_seed'] = torch.distributions.Uniform(low=torch.int32, high=torch.int32).sample((2)) protein['random_crop_to_size_seed'] = torch.distributions.Uniform(
low=torch.int32, high=torch.int32).sample((2)
)
return protein return protein
@curry1 @curry1
...@@ -110,8 +122,10 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -110,8 +122,10 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
torch.rand(protein['aatype'].shape) < replace_proportion torch.rand(protein['aatype'].shape) < replace_proportion
) )
protein['aatype'] = torch.where(aatype_mask, torch.ones_like(protein['aatype']) * x_idx, protein['aatype'] = torch.where(
protein['aatype']) aatype_mask, torch.ones_like(protein['aatype']) * x_idx,
protein['aatype']
)
return protein return protein
@curry1 @curry1
...@@ -151,7 +165,11 @@ def delete_extra_msa(protein): ...@@ -151,7 +165,11 @@ def delete_extra_msa(protein):
@curry1 @curry1
def block_delete_msa(protein, config): def block_delete_msa(protein, config):
num_seq = protein['msa'].shape[0] num_seq = protein['msa'].shape[0]
block_num_seq = torch.floor(torch.tensor(num_seq, dtype=torch.float32) * config.msa_fraction_per_block).to(torch.int32) block_num_seq = torch.floor(
torch.tensor(
num_seq, dtype=torch.float32
) * config.msa_fraction_per_block
).to(torch.int32)
if config.randomize_num_blocks: if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(0, config.num_blocks+1).sample() nb = torch.distributions.uniform.Uniform(0, config.num_blocks+1).sample()
...@@ -195,8 +213,11 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.): ...@@ -195,8 +213,11 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup. # in an optimized fashion to avoid possible memory or computation blowup.
agreement = torch.matmul(torch.reshape(extra_one_hot, [extra_num_seq, num_res*23]), agreement = torch.matmul(
torch.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).transpose(0, 1), torch.reshape(extra_one_hot, [extra_num_seq, num_res*23]),
torch.reshape(
sample_one_hot * weights, [num_seq, num_res * 23]
).transpose(0, 1),
) )
# Assign each sequence in the extra sequences to the closest MSA sample # Assign each sequence in the extra sequences to the closest MSA sample
...@@ -213,14 +234,18 @@ def unsorted_segment_sum(data, segment_ids, num_segments): ...@@ -213,14 +234,18 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
:param num_segments: The number of segments. :param num_segments: The number of segments.
:return: A tensor of same data type as the data argument. :return: A tensor of same data type as the data argument.
""" """
assert all([i in data.shape for i in segment_ids.shape]), "segment_ids.shape should be a prefix of data.shape" # segment_ids.shape should be a prefix of data.shape
assert all([i in data.shape for i in segment_ids.shape])
# segment_ids is a 1-D tensor repeat it to have the same shape as data # segment_ids is a 1-D tensor repeat it to have the same shape as data
if len(segment_ids.shape) == 1: if len(segment_ids.shape) == 1:
s = torch.prod(torch.tensor(data.shape[1:])).long() s = torch.prod(torch.tensor(data.shape[1:])).long()
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) segment_ids = segment_ids.repeat_interleave(s).view(
segment_ids.shape[0], *data.shape[1:]
)
assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" # data.shape and segment_ids.shape should be equal
assert data.shape == segment_ids.shape
shape = [num_segments] + list(data.shape[1:]) shape = [num_segments] + list(data.shape[1:])
tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float()) tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float())
...@@ -232,7 +257,9 @@ def summarize_clusters(protein): ...@@ -232,7 +257,9 @@ def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster.""" """Produce profile and deletion_matrix_mean within each cluster."""
num_seq = protein['msa'].shape[0] num_seq = protein['msa'].shape[0]
def csum(x): def csum(x):
return unsorted_segment_sum(x, protein['extra_cluster_assignment'], num_seq) return unsorted_segment_sum(
x, protein['extra_cluster_assignment'], num_seq
)
mask = protein['extra_msa_mask'] mask = protein['extra_msa_mask']
mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center
...@@ -292,7 +319,9 @@ def add_constant_field(protein, key, value): ...@@ -292,7 +319,9 @@ def add_constant_field(protein, key, value):
def shaped_categorical(probs, epsilon=1e-10): def shaped_categorical(probs, epsilon=1e-10):
ds = probs.shape ds = probs.shape
num_classes = ds[-1] num_classes = ds[-1]
distribution = torch.distributions.categorical.Categorical(torch.reshape(probs+epsilon,[-1, num_classes])) distribution = torch.distributions.categorical.Categorical(
torch.reshape(probs+epsilon,[-1, num_classes])
)
counts = distribution.sample() counts = distribution.sample()
return torch.reshape(counts, ds[:-1]) return torch.reshape(counts, ds[:-1])
...@@ -323,7 +352,9 @@ def make_masked_msa(protein, config, replace_fraction): ...@@ -323,7 +352,9 @@ def make_masked_msa(protein, config, replace_fraction):
pad_shapes[1] = 1 pad_shapes[1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0. assert mask_prob >= 0.
categorical_probs = torch.nn.functional.pad(categorical_probs, pad_shapes, value=mask_prob) categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob
)
sh = protein['msa'].shape sh = protein['msa'].shape
mask_position = torch.rand(sh) < replace_fraction mask_position = torch.rand(sh) < replace_fraction
...@@ -339,7 +370,14 @@ def make_masked_msa(protein, config, replace_fraction): ...@@ -339,7 +370,14 @@ def make_masked_msa(protein, config, replace_fraction):
return protein return protein
@curry1 @curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num_res=0, num_templates=0): def make_fixed_size(
protein,
shape_schema,
msa_cluster_size,
extra_msa_size,
num_res=0,
num_templates=0
):
"""Guess at the MSA and sequence dimension to make fixed size.""" """Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map = { pad_size_map = {
...@@ -355,9 +393,13 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num ...@@ -355,9 +393,13 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
continue continue
shape = list(v.shape) shape = list(v.shape)
schema = shape_schema[k] schema = shape_schema[k]
                msg = "Rank mismatch between shape and shape schema for"
assert len(shape) == len(schema), ( assert len(shape) == len(schema), (
f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}') f'{msg} {k}: {shape} vs {schema}'
pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] )
pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
]
padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
padding.reverse() padding.reverse()
...@@ -371,8 +413,11 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num ...@@ -371,8 +413,11 @@ def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, num
@curry1 @curry1
def make_msa_feat(protein): def make_msa_feat(protein):
"""Create and concatenate MSA features.""" """Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for compatibility with domain datasets. # Whether there is a domain break. Always zero for chains, but keeping for
has_break = torch.clip(protein['between_segment_residues'].to(torch.float32), 0, 1) # compatibility with domain datasets.
has_break = torch.clip(
protein['between_segment_residues'].to(torch.float32), 0, 1
)
aatype_1hot = make_one_hot(protein['aatype'], 21) aatype_1hot = make_one_hot(protein['aatype'], 21)
target_feat = [ target_feat = [
...@@ -391,14 +436,20 @@ def make_msa_feat(protein): ...@@ -391,14 +436,20 @@ def make_msa_feat(protein):
] ]
if 'cluster_profile' in protein: if 'cluster_profile' in protein:
deletion_mean_value = (torch.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi)) deletion_mean_value = (
torch.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi)
)
msa_feat.extend([protein['cluster_profile'], msa_feat.extend([protein['cluster_profile'],
torch.unsqueeze(deletion_mean_value, dim=-1), torch.unsqueeze(deletion_mean_value, dim=-1),
]) ])
if 'extra_deletion_matrix' in protein: if 'extra_deletion_matrix' in protein:
protein['extra_has_deletion'] = torch.clip(protein['extra_deletion_matrix'], 0., 1.) protein['extra_has_deletion'] = torch.clip(
protein['extra_deletion_value'] = torch.atan(protein['extra_deletion_matrix'] / 3.) * (2. / np.pi) protein['extra_deletion_matrix'], 0., 1.
)
protein['extra_deletion_value'] = torch.atan(
protein['extra_deletion_matrix'] / 3.
) * (2. / np.pi)
protein['msa_feat'] = torch.cat(msa_feat, dim=-1) protein['msa_feat'] = torch.cat(msa_feat, dim=-1)
protein['target_feat'] = torch.cat(target_feat, dim=-1) protein['target_feat'] = torch.cat(target_feat, dim=-1)
...@@ -422,35 +473,51 @@ def make_atom14_masks(protein): ...@@ -422,35 +473,51 @@ def make_atom14_masks(protein):
restype_atom14_mask = [] restype_atom14_mask = []
for rt in residue_constants.restypes: for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[residue_constants.restype_1to3[rt]] atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]
]
restype_atom14_to_atom37.append([ restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0) for name in atom_names (residue_constants.atom_order[name] if name else 0)
for name in atom_names
]) ])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([ restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types for name in residue_constants.atom_types
]) ])
# Since all 14 atoms are not present in every residue, use this mask to tell which atom is there in this residue # Since all 14 atoms are not present in every residue, use this mask to
# tell which atom is there in this residue
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK' # Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14) restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37) restype_atom37_to_atom14.append([0] * 37)
restype_atom14_to_atom37 = torch.tensor(restype_atom14_to_atom37, dtype=torch.int32) restype_atom14_to_atom37 = torch.tensor(
restype_atom37_to_atom14 = torch.tensor(restype_atom37_to_atom14, dtype=torch.int32) restype_atom14_to_atom37, dtype=torch.int32
restype_atom14_mask = torch.tensor(restype_atom14_mask, dtype=torch.float32) )
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14, dtype=torch.int32
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask, dtype=torch.float32
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array # create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein # with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = torch.index_select(restype_atom14_to_atom37, 0, protein['aatype']) residx_atom14_to_atom37 = torch.index_select(
residx_atom14_mask = torch.index_select(restype_atom14_mask, 0, protein['aatype']) restype_atom14_to_atom37, 0, protein['aatype']
)
residx_atom14_mask = torch.index_select(
restype_atom14_mask, 0, protein['aatype']
)
protein['atom14_atom_exists'] = residx_atom14_mask protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37 protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
# create the gather indices for mapping back # create the gather indices for mapping back
residx_atom37_to_atom14 = torch.index_select(restype_atom37_to_atom14, 0, protein['aatype']) residx_atom37_to_atom14 = torch.index_select(
restype_atom37_to_atom14, 0, protein['aatype']
)
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14 protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
# create the corresponding mask # create the corresponding mask
...@@ -462,7 +529,9 @@ def make_atom14_masks(protein): ...@@ -462,7 +529,9 @@ def make_atom14_masks(protein):
atom_type = residue_constants.atom_order[atom_name] atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1 restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = torch.index_select(restype_atom37_mask, 0, protein['aatype']) residx_atom37_mask = torch.index_select(
restype_atom37_mask, 0, protein['aatype']
)
protein['atom37_atom_exists'] = residx_atom37_mask protein['atom37_atom_exists'] = residx_atom37_mask
return protein return protein
...@@ -31,7 +31,7 @@ import time ...@@ -31,7 +31,7 @@ import time
import numpy as np import numpy as np
import torch import torch
from config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax import openfold.np.relax.relax as relax
......
...@@ -6,7 +6,7 @@ import unittest ...@@ -6,7 +6,7 @@ import unittest
import numpy as np import numpy as np
from config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_ from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts from tests.config import consts
......
...@@ -16,7 +16,7 @@ import torch ...@@ -16,7 +16,7 @@ import torch
import numpy as np import numpy as np
import unittest import unittest
from config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_ from openfold.utils.import_weights import import_jax_weights_
......
...@@ -17,7 +17,7 @@ import torch ...@@ -17,7 +17,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import unittest import unittest
from config import * from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map from openfold.utils.tensor_utils import tree_map, tensor_tree_map
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment