"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "204ed191059a5fadf993a7cab8ca4bd33b744c16"
Commit 49767099 authored by Gustaf Ahdritz

Bring tests up to speed

parent a6f56d16
@@ -64,7 +64,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
num_recycle = mlc.FieldReference(3, field_type=int)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
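The hunk above drops the shared `num_recycle` FieldReference (recycling counts now live in the data config as `max_recycling_iters`). For readers unfamiliar with ml_collections, here is a minimal standalone sketch, not part of the commit, of how a single FieldReference keeps several ConfigDict entries in sync:

```python
import ml_collections as mlc

# One reference can back several config fields; assigning to any field
# that holds it updates the shared value everywhere.
chunk_size = mlc.FieldReference(4, field_type=int)

cfg = mlc.ConfigDict({
    "globals": {"chunk_size": chunk_size},
    "model": {"chunk_size": chunk_size},
})

cfg.globals.chunk_size = 8
assert cfg.model.chunk_size == 8  # updated through the shared reference
```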
@@ -77,7 +76,6 @@ config = mlc.ConfigDict(
{
"data": {
"common": {
"batch_modes": [("clamped", 0.9), ("unclamped", 0.1)],
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
@@ -93,7 +91,7 @@ config = mlc.ConfigDict(
"backbone_affine_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None],
"chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None],
"extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
"extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
@@ -104,6 +102,7 @@ config = mlc.ConfigDict(
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_row_mask": [NUM_MSA_SEQ],
"no_recycling_iters": [],
"pseudo_beta": [NUM_RES, None],
"pseudo_beta_mask": [NUM_RES],
"residue_index": [NUM_RES],
@@ -149,8 +148,8 @@ config = mlc.ConfigDict(
"uniform_prob": 0.1,
},
"max_extra_msa": 1024,
"max_recycling_iters": 3,
"msa_cluster_features": True,
"num_recycle": num_recycle,
"reduce_msa_clusters_by_max_templates": False,
"resample_msa_in_recycling": True,
"template_features": [
@@ -167,9 +166,14 @@ config = mlc.ConfigDict(
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
],
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"supervised": {
"clamp_prob": 0.9,
"uniform_recycling": True,
"supervised_features": [
"all_atom_mask",
"all_atom_positions",
@@ -212,6 +216,8 @@ config = mlc.ConfigDict(
"crop": True,
"crop_size": 256,
"supervised": True,
"clamp_prob": 0.9,
"subsample_recycling": True,
},
"data_module": {
"use_small_bfd": False,
@@ -234,7 +240,6 @@ config = mlc.ConfigDict(
"eps": eps,
},
"model": {
"num_recycle": num_recycle,
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
......
@@ -5,6 +5,7 @@ import os
from typing import Optional, Sequence
import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
@@ -216,31 +217,66 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class OpenFoldBatchCollator:
def __init__(self, config, generator, stage="train"):
self.config = config
batch_modes = config.common.batch_modes
batch_mode_names, batch_mode_probs = list(zip(*batch_modes))
self.batch_mode_names = batch_mode_names
self.batch_mode_probs = batch_mode_probs
self.generator = generator
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
self._prep_batch_properties_probs()
self.batch_mode_probs_tensor = torch.tensor(self.batch_mode_probs)
def _prep_batch_properties_probs(self):
keyed_probs = []
stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(self.config.supervised.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(self.config)
keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs])
padding = [[0.] * (max_len - len(p)) for p in probs]
self.prop_keys = keys
self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)],
dtype=torch.float32,
)
def __call__(self, raw_prots):
# We use torch.multinomial here rather than Categorical because the
# latter doesn't accept a generator for some reason
batch_mode_idx = torch.multinomial(
self.batch_mode_probs_tensor,
1,
def _add_batch_properties(self, raw_prots):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
replacement=True,
generator=self.generator
).item()
batch_mode_name = self.batch_mode_names[batch_mode_idx]
)
for i, key in enumerate(self.prop_keys):
sample = samples[i][0]
for prot in raw_prots:
prot[key] = np.array(sample, dtype=np.float32)
def __call__(self, raw_prots):
self._add_batch_properties(raw_prots)
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage, batch_mode_name
prot, self.stage
)
processed_prots.append(features)
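The new `_prep_batch_properties_probs`/`_add_batch_properties` pair replaces the old clamped/unclamped batch modes with per-property categorical sampling. A standalone sketch of the scheme (the probabilities are assumed from the config defaults above, clamp_prob = 0.9 and max_recycling_iters = 3):

```python
import torch

# Each row is the distribution of one batch-level property; shorter rows
# are zero-padded on the right so they stack into a single tensor.
gen = torch.Generator().manual_seed(0)
prop_probs = torch.tensor([
    [0.1, 0.9, 0.0, 0.0],        # use_clamped_fape: P(0)=0.1, P(1)=0.9
    [0.25, 0.25, 0.25, 0.25],    # no_recycling_iters: uniform over 0..3
])

# One draw per row; the shared generator keeps draws reproducible.
samples = torch.multinomial(
    prop_probs, num_samples=1, replacement=True, generator=gen
)
use_clamped_fape = samples[0][0].item()
no_recycling_iters = samples[1][0].item()
```

Zero padding is harmless here because torch.multinomial never selects zero-probability categories.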
@@ -264,7 +300,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
@@ -286,6 +323,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.template_release_dates_cache_path = (
template_release_dates_cache_path
)
self.batch_seed = batch_seed
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
@@ -309,7 +347,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
def setup(self, stage):
def setup(self, stage: Optional[str] = None):
if(stage is None):
stage = "train"
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
@@ -369,12 +410,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode="predict",
)
self.batch_collation_seed = torch.Generator().seed()
def _gen_batch_collator(self, stage):
""" We want each process to use the same batch collation seed """
generator = torch.Generator()
generator = generator.manual_seed(self.batch_collation_seed)
if(self.batch_seed is not None):
generator = generator.manual_seed(self.batch_seed)
collate_fn = OpenFoldBatchCollator(
self.config, generator, stage
)
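The new `batch_seed` argument exists so that every process builds its collator around an identically seeded generator. A quick standalone illustration of why that is sufficient:

```python
import torch

# Generators with the same seed produce identical draws, so collators in
# different processes sample the same batch properties.
g1 = torch.Generator().manual_seed(1234)
g2 = torch.Generator().manual_seed(1234)
probs = torch.tensor([[0.25, 0.25, 0.25, 0.25]])
assert torch.equal(
    torch.multinomial(probs, 1, replacement=True, generator=g1),
    torch.multinomial(probs, 1, replacement=True, generator=g2),
)
```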
@@ -404,5 +444,5 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("eval")
collate_fn=self._gen_batch_collator("predict")
)
@@ -1095,7 +1095,6 @@ def random_crop_to_size(
shape_schema,
subsample_templates=False,
seed=None,
batch_mode="clamped",
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length = protein["seq_length"]
@@ -1133,13 +1132,11 @@ def random_crop_to_size(
num_templates_crop_size = num_templates
n = seq_length - num_res_crop_size
if batch_mode == "clamped":
if protein["use_clamped_fape"] == 1.:
right_anchor = n
elif batch_mode == "unclamped":
else:
x = _randint(0, n)
right_anchor = n - x
else:
raise ValueError("Invalid batch mode")
num_res_crop_start = _randint(0, right_anchor)
......
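The crop change swaps the string-valued `batch_mode` for the `use_clamped_fape` feature itself. A standalone sketch of the anchoring rule, where `randint_inc` mirrors the file's `_randint` helper and is assumed inclusive of both endpoints:

```python
import torch

def randint_inc(lower, upper, g=None):
    # Inclusive randint, assumed to match the file's _randint helper.
    return int(torch.randint(lower, upper + 1, (1,), generator=g))

def crop_start(seq_length, crop_size, use_clamped_fape, g=None):
    n = max(seq_length - crop_size, 0)
    if use_clamped_fape == 1.:
        right_anchor = n                 # clamped: crop may start anywhere
    else:
        right_anchor = n - randint_inc(0, n, g)  # unclamped: narrow the range
    return randint_inc(0, right_anchor, g)
```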
@@ -64,7 +64,7 @@ def make_data_config(
feature_names += cfg.common.template_features
if cfg[mode].supervised:
feature_names += cfg.common.supervised_features
feature_names += cfg.supervised.supervised_features
return cfg, feature_names
@@ -73,7 +73,6 @@ def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
mode: str,
batch_mode: str,
):
np_example = dict(np_example)
num_res = int(np_example["seq_length"][0])
@@ -84,11 +83,6 @@ def np_example_to_features(
"deletion_matrix_int"
).astype(np.float32)
if batch_mode == "clamped":
np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
elif batch_mode == "unclamped":
np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
@@ -97,7 +91,6 @@
tensor_dict,
cfg.common,
cfg[mode],
batch_mode=batch_mode,
)
return {k: v for k, v in features.items()}
@@ -115,12 +108,10 @@ class FeaturePipeline:
def process_features(
self,
raw_features: FeatureDict,
mode: str = "train",
batch_mode: str = "clamped",
mode: str = "train",
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
config=self.config,
mode=mode,
batch_mode=batch_mode,
)
@@ -68,7 +68,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
@@ -116,7 +116,6 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
batch_mode=batch_mode,
seed=ensemble_seed,
)
)
@@ -137,9 +136,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
return transforms
def process_tensors_from_config(
tensors, common_cfg, mode_cfg, batch_mode="clamped"
):
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
@@ -150,7 +147,6 @@ def process_tensors_from_config(
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
batch_mode,
ensemble_seed,
)
fn = compose(fns)
@@ -160,9 +156,11 @@
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
num_ensemble = mode_cfg.num_ensemble
num_recycling = tensors["no_recycling_iters"].item()
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= common_cfg.num_recycle + 1
num_ensemble *= num_recycling + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(
......
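With recycling counts now sampled per example, the ensemble dimension is sized from `no_recycling_iters` instead of a fixed `num_recycle` config value. A rough sketch of the stacking this produces, assuming (as the model code below suggests) that recycling copies land on a trailing dimension; `ensembled_fn` stands in for `compose(fns)`:

```python
import torch

def stack_recycled(ensembled_fn, tensors, no_recycling_iters):
    # Re-run the ensembled transforms once per recycling iteration and
    # stack the copies so the model can index them with t[..., cycle_no].
    copies = [ensembled_fn(dict(tensors)) for _ in range(no_recycling_iters + 1)]
    return {k: torch.stack([c[k] for c in copies], dim=-1) for k in copies[0]}
```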
@@ -202,7 +202,7 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if self.config.num_recycle > 0:
if feats["no_recycling_iters"] > 0:
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
@@ -236,7 +236,7 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z = z + z_prev_emb
# This can matter during inference when N_res is very large
# Possibly prevents memory fragmentation
del m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
@@ -395,19 +395,21 @@ class AlphaFold(nn.Module):
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop
for cycle_no in range(self.config.num_recycle + 1):
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == self.config.num_recycle
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug discussed in pytorch issue #65766
# Sidestep AMP bug (PyTorch issue #65766)
if is_final_iter:
self._enable_activation_checkpointing()
if torch.is_autocast_enabled():
......
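The main-loop change reads the iteration count off the batch's trailing recycling dimension rather than the config. A condensed sketch of the control flow, with `iteration_fn` as a hypothetical stand-in for one pass through the trunk:

```python
import torch

def run_recycling(iteration_fn, batch, training=True):
    outputs, prevs = None, None
    num_iters = batch["aatype"].shape[-1]  # trailing recycling dimension
    for cycle_no in range(num_iters):
        feats = {k: t[..., cycle_no] for k, t in batch.items()}
        is_final_iter = cycle_no == (num_iters - 1)
        # Gradients flow only through the final recycling pass.
        with torch.set_grad_enabled(training and is_final_iter):
            outputs, prevs = iteration_fn(feats, prevs)
    return outputs
```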
@@ -258,11 +258,14 @@ class Attention(nn.Module):
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
# [*, H, Q, C_hidden]
q = permute_final_dims(q, (1, 0, 2))
# [*, H, C_hidden, K]
k = permute_final_dims(k, (1, 2, 0))
# [*, H, Q, K]
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, Q, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, K]
)
a = torch.matmul(q, k)
del q, k
@@ -273,11 +276,11 @@ class Attention(nn.Module):
a = a + b
a = self.softmax(a)
# [*, H, V, C_hidden]
v = permute_final_dims(v, (1, 0, 2))
# [*, H, Q, C_hidden]
o = torch.matmul(
a,
permute_final_dims(v, (1, 0, 2)), # [*, H, V, C_hidden]
)
o = torch.matmul(a, v)
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
......
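The attention refactor pulls the head/sequence permutations out of the matmul calls, making the intermediate shapes explicit. A standalone sketch; the helper is reimplemented here from its usage and may differ in detail from the library's version:

```python
import torch

def permute_final_dims(tensor, inds):
    # Reorder only the last len(inds) dimensions, leaving batch dims alone.
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

q = torch.rand(2, 10, 4, 8)  # [batch, Q, heads, C_hidden]
k = torch.rand(2, 12, 4, 8)  # [batch, K, heads, C_hidden]
a = torch.matmul(
    permute_final_dims(q, (1, 0, 2)),  # [batch, H, Q, C_hidden]
    permute_final_dims(k, (1, 2, 0)),  # [batch, H, C_hidden, K]
)
assert a.shape == (2, 4, 10, 12)       # [batch, H, Q, K]
```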
#!/bin/bash
#CUDA_VISIBLE_DEVICES="5"
python3 -m unittest "$@" || \
echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
@@ -60,7 +60,7 @@ _model = None
def get_global_pretrained_openfold():
global _model
if _model is None:
_model = AlphaFold(model_config("model_1_ptm").model)
_model = AlphaFold(model_config("model_1_ptm"))
_model = _model.eval()
if not os.path.exists(_param_path):
raise FileNotFoundError(
......
@@ -25,11 +25,17 @@ def random_template_feats(n_templ, n, batch_size=None):
"template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
"template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
"template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
"template_all_atom_masks": np.random.randint(
"template_all_atom_mask": np.random.randint(
0, 2, (*b, n_templ, n, 37)
),
"template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3)
* 10,
"template_all_atom_positions":
np.random.rand(*b, n_templ, n, 37, 3) * 10,
"template_torsion_angles_sin_cos":
np.random.rand(*b, n_templ, n, 7, 2),
"template_alt_torsion_angles_sin_cos":
np.random.rand(*b, n_templ, n, 7, 2),
"template_torsion_angles_mask":
np.random.rand(*b, n_templ, n, 7),
}
batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
......
@@ -66,7 +66,6 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
chunk_size=4,
inf=inf,
eps=eps,
).eval()
@@ -79,7 +78,9 @@ class TestEvoformerStack(unittest.TestCase):
shape_m_before = m.shape
shape_z_before = z.shape
m, z, s = es(m, z, msa_mask, pair_mask)
m, z, s = es(
m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
)
self.assertTrue(m.shape == shape_m_before)
self.assertTrue(z.shape == shape_z_before)
@@ -127,6 +128,7 @@ class TestEvoformerStack(unittest.TestCase):
torch.as_tensor(activations["pair"]).cuda(),
torch.as_tensor(masks["msa"]).cuda(),
torch.as_tensor(masks["pair"]).cuda(),
chunk_size=4,
_mask_trans=False,
)
@@ -171,7 +173,6 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
chunk_size=4,
inf=inf,
eps=eps,
).eval()
@@ -199,7 +200,7 @@ class TestExtraMSAStack(unittest.TestCase):
shape_z_before = z.shape
z = es(m, z, msa_mask, pair_mask)
z = es(m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask)
self.assertTrue(z.shape == shape_z_before)
@@ -212,12 +213,12 @@ class TestMSATransition(unittest.TestCase):
c_m = 7
n = 11
mt = MSATransition(c_m, n, chunk_size=4)
mt = MSATransition(c_m, n)
m = torch.rand((batch_size, s_t, n_r, c_m))
shape_before = m.shape
m = mt(m)
m = mt(m, chunk_size=4)
shape_after = m.shape
self.assertTrue(shape_before == shape_after)
......
@@ -16,7 +16,7 @@ import torch
import numpy as np
import unittest
import openfold.features.data_transforms as data_transforms
import openfold.data.data_transforms as data_transforms
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
@@ -102,10 +102,12 @@ class TestFeats(unittest.TestCase):
out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)
out_repro = feats.atom37_to_torsion_angles(
torch.as_tensor(aatype).cuda(),
torch.as_tensor(all_atom_pos).cuda(),
torch.as_tensor(all_atom_mask).cuda(),
out_repro = data_transforms.atom37_to_torsion_angles()(
{
"aatype": torch.as_tensor(aatype).cuda(),
"all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
"all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
},
)
tasc = out_repro["torsion_angles_sin_cos"].cpu()
atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
......
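The test now goes through `data_transforms.atom37_to_torsion_angles()`, a transform factory applied to a feature dict rather than a plain function over tensors. A toy sketch of that curried pattern; the name and body here are illustrative only:

```python
import torch

def make_prefixed_transform(prefix=""):
    # Factory: configuration in, dict-to-dict transform out.
    def transform(protein):
        aatype = protein[prefix + "aatype"]
        protein[prefix + "n_res"] = torch.tensor(aatype.shape[-1])
        return protein
    return transform

batch = {"template_aatype": torch.zeros(4, 16, dtype=torch.long)}
batch = make_prefixed_transform("template_")(batch)
```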
@@ -27,7 +27,7 @@ class TestImportWeights(unittest.TestCase):
c = model_config("model_1_ptm")
c.globals.blocks_per_ckpt = None
model = AlphaFold(c.model)
model = AlphaFold(c)
import_jax_weights_(
model,
......
@@ -19,7 +19,7 @@ import numpy as np
import unittest
import ml_collections as mlc
from openfold.features import data_transforms
from openfold.data import data_transforms
from openfold.utils.affine_utils import T, affine_vector_to_4x4
import openfold.utils.feats as feats
from openfold.utils.loss import (
......
@@ -18,7 +18,7 @@ import torch.nn as nn
import numpy as np
import unittest
from openfold.config import model_config
from openfold.features.data_transforms import make_atom14_masks
from openfold.data import data_transforms
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
@@ -42,22 +42,21 @@ class TestModel(unittest.TestCase):
n_res = consts.n_res
n_extra_seq = consts.n_extra
c = model_config("model_1").model
c.no_cycles = 2
c.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.evoformer_stack.blocks_per_ckpt = None # don't want to set up
c = model_config("model_1")
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
model = AlphaFold(c)
batch = {}
tf = torch.randint(c.input_embedder.tf_dim - 1, size=(n_res,))
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = nn.functional.one_hot(
tf, c.input_embedder.tf_dim
tf, c.model.input_embedder.tf_dim
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.input_embedder.msa_dim))
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
@@ -66,10 +65,11 @@ class TestModel(unittest.TestCase):
low=0, high=2, size=(n_seq, n_res)
).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(make_atom14_masks(batch))
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.no_cycles)
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
batch = tensor_tree_map(add_recycling_dims, batch)
@@ -94,7 +94,7 @@ class TestModel(unittest.TestCase):
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
out_gt = jax.jit(f.apply)(params, jax.random.PRNGKey(42), batch)
out_gt = f.apply(params, jax.random.PRNGKey(42), batch)
out_gt = out_gt["structure_module"]["final_atom_positions"]
# atom37_to_atom14 doesn't like batches
@@ -103,13 +103,19 @@ class TestModel(unittest.TestCase):
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
)
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
......
@@ -41,13 +41,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
no_heads = 4
chunk_size = None
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads)
m = torch.rand((batch_size, n_seq, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
shape_before = m.shape
m = mrapb(m, z)
m = mrapb(m, z=z, chunk_size=chunk_size)
shape_after = m.shape
self.assertTrue(shape_before == shape_after)
@@ -91,8 +91,9 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model.evoformer.blocks[0]
.msa_att_row(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
z=torch.as_tensor(pair_act).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(),
)
.cpu()
)
@@ -114,7 +115,7 @@ class TestMSAColumnAttention(unittest.TestCase):
x = torch.rand((batch_size, n_seq, n_res, c_m))
shape_before = x.shape
x = msaca(x)
x = msaca(x, chunk_size=None)
shape_after = x.shape
self.assertTrue(shape_before == shape_after)
@@ -155,7 +156,8 @@ class TestMSAColumnAttention(unittest.TestCase):
model.evoformer.blocks[0]
.msa_att_col(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(),
)
.cpu()
)
@@ -177,7 +179,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
x = torch.rand((batch_size, n_seq, n_res, c_m))
shape_before = x.shape
x = msagca(x)
x = msagca(x, chunk_size=None)
shape_after = x.shape
self.assertTrue(shape_before == shape_after)
@@ -219,6 +221,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model.extra_msa_stack.stack.blocks[0]
.msa_att_col(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
)
.cpu()
......
@@ -38,11 +38,11 @@ class TestOuterProductMean(unittest.TestCase):
mask = torch.randint(
0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res)
)
m = opm(m, mask)
m = opm(m, mask=mask, chunk_size=None)
self.assertTrue(
m.shape
== (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
m.shape ==
(consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
)
@compare_utils.skip_unless_alphafold_installed()
@@ -84,6 +84,7 @@ class TestOuterProductMean(unittest.TestCase):
model.evoformer.blocks[0]
.outer_product_mean(
torch.as_tensor(msa_act).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(),
)
.cpu()
......
@@ -39,7 +39,7 @@ class TestPairTransition(unittest.TestCase):
z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
shape_before = z.shape
z = pt(z, mask)
z = pt(z, mask=mask, chunk_size=None)
shape_after = z.shape
self.assertTrue(shape_before == shape_after)
@@ -79,6 +79,7 @@ class TestPairTransition(unittest.TestCase):
model.evoformer.blocks[0]
.pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4,
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
)
.cpu()
......
@@ -16,7 +16,7 @@ import torch
import numpy as np
import unittest
from openfold.features.data_transforms import make_atom14_masks_np
from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
@@ -174,7 +174,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.01)
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
class TestBackboneUpdate(unittest.TestCase):
......
@@ -42,13 +42,13 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
inf = 1e7
tpa = TemplatePointwiseAttention(
c_t, c_z, c, no_heads, chunk_size=4, inf=inf
c_t, c_z, c, no_heads, inf=inf
)
t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
z = torch.rand((batch_size, n_res, n_res, c_z))
z_update = tpa(t, z)
z_update = tpa(t, z, chunk_size=None)
self.assertTrue(z_update.shape == z.shape)
@@ -79,7 +79,6 @@ class TestTemplatePairStack(unittest.TestCase):
pair_transition_n=pt_inner_dim,
dropout_rate=dropout,
blocks_per_ckpt=None,
chunk_size=chunk_size,
inf=inf,
eps=eps,
)
@@ -87,7 +86,7 @@ class TestTemplatePairStack(unittest.TestCase):
t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
mask = torch.randint(0, 2, (batch_size, n_templ, n_res, n_res))
shape_before = t.shape
t = tpe(t, mask)
t = tpe(t, mask, chunk_size=chunk_size)
shape_after = t.shape
self.assertTrue(shape_before == shape_after)
@@ -136,6 +135,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_repro = model.template_pair_stack(
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
chunk_size=None,
_mask_trans=False,
).cpu()
@@ -161,8 +161,8 @@ class Template(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding"
@@ -182,6 +182,7 @@ class Template(unittest.TestCase):
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=None,
)
out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro.cpu()
......