Commit 49767099 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Bring tests up to speed

parent a6f56d16
...@@ -64,7 +64,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int) ...@@ -64,7 +64,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int) chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int) aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float) eps = mlc.FieldReference(1e-8, field_type=float)
num_recycle = mlc.FieldReference(3, field_type=int)
templates_enabled = mlc.FieldReference(True, field_type=bool) templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool) embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
...@@ -77,7 +76,6 @@ config = mlc.ConfigDict( ...@@ -77,7 +76,6 @@ config = mlc.ConfigDict(
{ {
"data": { "data": {
"common": { "common": {
"batch_modes": [("clamped", 0.9), ("unclamped", 0.1)],
"feat": { "feat": {
"aatype": [NUM_RES], "aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None], "all_atom_mask": [NUM_RES, None],
...@@ -93,7 +91,7 @@ config = mlc.ConfigDict( ...@@ -93,7 +91,7 @@ config = mlc.ConfigDict(
"backbone_affine_mask": [NUM_RES], "backbone_affine_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None], "backbone_affine_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES], "bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None], "chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None], "chi_mask": [NUM_RES, None],
"extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES], "extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
"extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES], "extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
...@@ -104,6 +102,7 @@ config = mlc.ConfigDict( ...@@ -104,6 +102,7 @@ config = mlc.ConfigDict(
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None], "msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES], "msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_row_mask": [NUM_MSA_SEQ], "msa_row_mask": [NUM_MSA_SEQ],
"no_recycling_iters": [],
"pseudo_beta": [NUM_RES, None], "pseudo_beta": [NUM_RES, None],
"pseudo_beta_mask": [NUM_RES], "pseudo_beta_mask": [NUM_RES],
"residue_index": [NUM_RES], "residue_index": [NUM_RES],
...@@ -149,8 +148,8 @@ config = mlc.ConfigDict( ...@@ -149,8 +148,8 @@ config = mlc.ConfigDict(
"uniform_prob": 0.1, "uniform_prob": 0.1,
}, },
"max_extra_msa": 1024, "max_extra_msa": 1024,
"max_recycling_iters": 3,
"msa_cluster_features": True, "msa_cluster_features": True,
"num_recycle": num_recycle,
"reduce_msa_clusters_by_max_templates": False, "reduce_msa_clusters_by_max_templates": False,
"resample_msa_in_recycling": True, "resample_msa_in_recycling": True,
"template_features": [ "template_features": [
...@@ -167,9 +166,14 @@ config = mlc.ConfigDict( ...@@ -167,9 +166,14 @@ config = mlc.ConfigDict(
"seq_length", "seq_length",
"between_segment_residues", "between_segment_residues",
"deletion_matrix", "deletion_matrix",
"no_recycling_iters",
], ],
"use_templates": templates_enabled, "use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles, "use_template_torsion_angles": embed_template_torsion_angles,
},
"supervised": {
"clamp_prob": 0.9,
"uniform_recycling": True,
"supervised_features": [ "supervised_features": [
"all_atom_mask", "all_atom_mask",
"all_atom_positions", "all_atom_positions",
...@@ -212,6 +216,8 @@ config = mlc.ConfigDict( ...@@ -212,6 +216,8 @@ config = mlc.ConfigDict(
"crop": True, "crop": True,
"crop_size": 256, "crop_size": 256,
"supervised": True, "supervised": True,
"clamp_prob": 0.9,
"subsample_recycling": True,
}, },
"data_module": { "data_module": {
"use_small_bfd": False, "use_small_bfd": False,
...@@ -234,7 +240,6 @@ config = mlc.ConfigDict( ...@@ -234,7 +240,6 @@ config = mlc.ConfigDict(
"eps": eps, "eps": eps,
}, },
"model": { "model": {
"num_recycle": num_recycle,
"_mask_trans": False, "_mask_trans": False,
"input_embedder": { "input_embedder": {
"tf_dim": 22, "tf_dim": 22,
......
...@@ -5,6 +5,7 @@ import os ...@@ -5,6 +5,7 @@ import os
from typing import Optional, Sequence from typing import Optional, Sequence
import ml_collections as mlc import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.utils.data import RandomSampler from torch.utils.data import RandomSampler
...@@ -216,31 +217,66 @@ class OpenFoldDataset(torch.utils.data.IterableDataset): ...@@ -216,31 +217,66 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __init__(self, config, generator, stage="train"): def __init__(self, config, generator, stage="train"):
self.config = config self.config = config
batch_modes = config.common.batch_modes
batch_mode_names, batch_mode_probs = list(zip(*batch_modes))
self.batch_mode_names = batch_mode_names
self.batch_mode_probs = batch_mode_probs
self.generator = generator self.generator = generator
self.stage = stage self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
self._prep_batch_properties_probs()
self.batch_mode_probs_tensor = torch.tensor(self.batch_mode_probs) def _prep_batch_properties_probs(self):
keyed_probs = []
stage_cfg = self.config[self.stage]
self.feature_pipeline = feature_pipeline.FeaturePipeline(self.config) max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(self.config.supervised.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
def __call__(self, raw_prots): keys, probs = zip(*keyed_probs)
# We use torch.multinomial here rather than Categorical because the max_len = max([len(p) for p in probs])
# latter doesn't accept a generator for some reason padding = [[0.] * (max_len - len(p)) for p in probs]
batch_mode_idx = torch.multinomial(
self.batch_mode_probs_tensor, self.prop_keys = keys
1, self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)],
dtype=torch.float32,
)
def _add_batch_properties(self, raw_prots):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
replacement=True,
generator=self.generator generator=self.generator
).item() )
batch_mode_name = self.batch_mode_names[batch_mode_idx]
for i, key in enumerate(self.prop_keys):
sample = samples[i][0]
for prot in raw_prots:
prot[key] = np.array(sample, dtype=np.float32)
def __call__(self, raw_prots):
self._add_batch_properties(raw_prots)
processed_prots = [] processed_prots = []
for prot in raw_prots: for prot in raw_prots:
features = self.feature_pipeline.process_features( features = self.feature_pipeline.process_features(
prot, self.stage, batch_mode_name prot, self.stage
) )
processed_prots.append(features) processed_prots.append(features)
...@@ -265,6 +301,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -265,6 +301,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
train_mapping_path: Optional[str] = None, train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None, distillation_mapping_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
**kwargs **kwargs
): ):
super(OpenFoldDataModule, self).__init__() super(OpenFoldDataModule, self).__init__()
...@@ -286,6 +323,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -286,6 +323,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.template_release_dates_cache_path = ( self.template_release_dates_cache_path = (
template_release_dates_cache_path template_release_dates_cache_path
) )
self.batch_seed = batch_seed
if(self.train_data_dir is None and self.predict_data_dir is None): if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError( raise ValueError(
...@@ -309,7 +347,10 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -309,7 +347,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well' 'be specified as well'
) )
def setup(self, stage): def setup(self, stage: Optional[str] = None):
if(stage is None):
stage = "train"
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset, dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir, template_mmcif_dir=self.template_mmcif_dir,
...@@ -369,12 +410,11 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -369,12 +410,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode="predict", mode="predict",
) )
self.batch_collation_seed = torch.Generator().seed()
def _gen_batch_collator(self, stage): def _gen_batch_collator(self, stage):
""" We want each process to use the same batch collation seed """ """ We want each process to use the same batch collation seed """
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(self.batch_collation_seed) if(self.batch_seed is not None):
generator = generator.manual_seed(self.batch_seed)
collate_fn = OpenFoldBatchCollator( collate_fn = OpenFoldBatchCollator(
self.config, generator, stage self.config, generator, stage
) )
...@@ -404,5 +444,5 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -404,5 +444,5 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_dataset, self.predict_dataset,
batch_size=self.config.data_module.data_loaders.batch_size, batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers, num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("eval") collate_fn=self._gen_batch_collator("predict")
) )
...@@ -1095,7 +1095,6 @@ def random_crop_to_size( ...@@ -1095,7 +1095,6 @@ def random_crop_to_size(
shape_schema, shape_schema,
subsample_templates=False, subsample_templates=False,
seed=None, seed=None,
batch_mode="clamped",
): ):
"""Crop randomly to `crop_size`, or keep as is if shorter than that.""" """Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length = protein["seq_length"] seq_length = protein["seq_length"]
...@@ -1133,13 +1132,11 @@ def random_crop_to_size( ...@@ -1133,13 +1132,11 @@ def random_crop_to_size(
num_templates_crop_size = num_templates num_templates_crop_size = num_templates
n = seq_length - num_res_crop_size n = seq_length - num_res_crop_size
if batch_mode == "clamped": if protein["use_clamped_fape"] == 1.:
right_anchor = n right_anchor = n
elif batch_mode == "unclamped": else:
x = _randint(0, n) x = _randint(0, n)
right_anchor = n - x right_anchor = n - x
else:
raise ValueError("Invalid batch mode")
num_res_crop_start = _randint(0, right_anchor) num_res_crop_start = _randint(0, right_anchor)
......
...@@ -64,7 +64,7 @@ def make_data_config( ...@@ -64,7 +64,7 @@ def make_data_config(
feature_names += cfg.common.template_features feature_names += cfg.common.template_features
if cfg[mode].supervised: if cfg[mode].supervised:
feature_names += cfg.common.supervised_features feature_names += cfg.supervised.supervised_features
return cfg, feature_names return cfg, feature_names
...@@ -73,7 +73,6 @@ def np_example_to_features( ...@@ -73,7 +73,6 @@ def np_example_to_features(
np_example: FeatureDict, np_example: FeatureDict,
config: ml_collections.ConfigDict, config: ml_collections.ConfigDict,
mode: str, mode: str,
batch_mode: str,
): ):
np_example = dict(np_example) np_example = dict(np_example)
num_res = int(np_example["seq_length"][0]) num_res = int(np_example["seq_length"][0])
...@@ -84,11 +83,6 @@ def np_example_to_features( ...@@ -84,11 +83,6 @@ def np_example_to_features(
"deletion_matrix_int" "deletion_matrix_int"
).astype(np.float32) ).astype(np.float32)
if batch_mode == "clamped":
np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
elif batch_mode == "unclamped":
np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)
tensor_dict = np_to_tensor_dict( tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names np_example=np_example, features=feature_names
) )
...@@ -97,7 +91,6 @@ def np_example_to_features( ...@@ -97,7 +91,6 @@ def np_example_to_features(
tensor_dict, tensor_dict,
cfg.common, cfg.common,
cfg[mode], cfg[mode],
batch_mode=batch_mode,
) )
return {k: v for k, v in features.items()} return {k: v for k, v in features.items()}
...@@ -116,11 +109,9 @@ class FeaturePipeline: ...@@ -116,11 +109,9 @@ class FeaturePipeline:
self, self,
raw_features: FeatureDict, raw_features: FeatureDict,
mode: str = "train", mode: str = "train",
batch_mode: str = "clamped",
) -> FeatureDict: ) -> FeatureDict:
return np_example_to_features( return np_example_to_features(
np_example=raw_features, np_example=raw_features,
config=self.config, config=self.config,
mode=mode, mode=mode,
batch_mode=batch_mode,
) )
...@@ -68,7 +68,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -68,7 +68,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
return transforms return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed): def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged.""" """Input pipeline data transformers that can be ensembled and averaged."""
transforms = [] transforms = []
...@@ -116,7 +116,6 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed): ...@@ -116,7 +116,6 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
mode_cfg.max_templates, mode_cfg.max_templates,
crop_feats, crop_feats,
mode_cfg.subsample_templates, mode_cfg.subsample_templates,
batch_mode=batch_mode,
seed=ensemble_seed, seed=ensemble_seed,
) )
) )
...@@ -137,9 +136,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed): ...@@ -137,9 +136,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
return transforms return transforms
def process_tensors_from_config( def process_tensors_from_config(tensors, common_cfg, mode_cfg):
tensors, common_cfg, mode_cfg, batch_mode="clamped"
):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed() ensemble_seed = torch.Generator().seed()
...@@ -150,7 +147,6 @@ def process_tensors_from_config( ...@@ -150,7 +147,6 @@ def process_tensors_from_config(
fns = ensembled_transform_fns( fns = ensembled_transform_fns(
common_cfg, common_cfg,
mode_cfg, mode_cfg,
batch_mode,
ensemble_seed, ensemble_seed,
) )
fn = compose(fns) fn = compose(fns)
...@@ -160,9 +156,11 @@ def process_tensors_from_config( ...@@ -160,9 +156,11 @@ def process_tensors_from_config(
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors) tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
num_ensemble = mode_cfg.num_ensemble num_ensemble = mode_cfg.num_ensemble
num_recycling = tensors["no_recycling_iters"].item()
if common_cfg.resample_msa_in_recycling: if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step. # Separate batch per ensembling & recycling step.
num_ensemble *= common_cfg.num_recycle + 1 num_ensemble *= num_recycling + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1: if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn( tensors = map_fn(
......
...@@ -202,7 +202,7 @@ class AlphaFold(nn.Module): ...@@ -202,7 +202,7 @@ class AlphaFold(nn.Module):
) )
# Inject information from previous recycling iterations # Inject information from previous recycling iterations
if self.config.num_recycle > 0: if feats["no_recycling_iters"] > 0:
# Initialize the recycling embeddings, if needs be # Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]: if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m] # [*, N, C_m]
...@@ -236,7 +236,7 @@ class AlphaFold(nn.Module): ...@@ -236,7 +236,7 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + z_prev_emb z = z + z_prev_emb
# This can matter during inference when N_res is very large # Possibly prevents memory fragmentation
del m_1_prev_emb, z_prev_emb del m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
...@@ -395,19 +395,21 @@ class AlphaFold(nn.Module): ...@@ -395,19 +395,21 @@ class AlphaFold(nn.Module):
# Initialize recycling embeddings # Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing() self._disable_activation_checkpointing()
# Main recycling loop # Main recycling loop
for cycle_no in range(self.config.num_recycle + 1): num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle # Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no] fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch) feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer # Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == self.config.num_recycle is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter): with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug discussed in pytorch issue #65766 # Sidestep AMP bug (PyTorch issue #65766)
if is_final_iter: if is_final_iter:
self._enable_activation_checkpointing() self._enable_activation_checkpointing()
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
......
...@@ -258,11 +258,14 @@ class Attention(nn.Module): ...@@ -258,11 +258,14 @@ class Attention(nn.Module):
k = k.view(k.shape[:-1] + (self.no_heads, -1)) k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1)) v = v.view(v.shape[:-1] + (self.no_heads, -1))
# [*, H, Q, C_hidden]
q = permute_final_dims(q, (1, 0, 2))
# [*, H, C_hidden, K]
k = permute_final_dims(k, (1, 2, 0))
# [*, H, Q, K] # [*, H, Q, K]
a = torch.matmul( a = torch.matmul(q, k)
permute_final_dims(q, (1, 0, 2)), # [*, H, Q, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, K]
)
del q, k del q, k
...@@ -273,11 +276,11 @@ class Attention(nn.Module): ...@@ -273,11 +276,11 @@ class Attention(nn.Module):
a = a + b a = a + b
a = self.softmax(a) a = self.softmax(a)
# [*, H, V, C_hidden]
v = permute_final_dims(v, (1, 0, 2))
# [*, H, Q, C_hidden] # [*, H, Q, C_hidden]
o = torch.matmul( o = torch.matmul(a, v)
a,
permute_final_dims(v, (1, 0, 2)), # [*, H, V, C_hidden]
)
# [*, Q, H, C_hidden] # [*, Q, H, C_hidden]
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
......
#!/bin/bash #!/bin/bash
#CUDA_VISIBLE_DEVICES="5"
python3 -m unittest "$@" || \ python3 -m unittest "$@" || \
echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies." echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
...@@ -60,7 +60,7 @@ _model = None ...@@ -60,7 +60,7 @@ _model = None
def get_global_pretrained_openfold(): def get_global_pretrained_openfold():
global _model global _model
if _model is None: if _model is None:
_model = AlphaFold(model_config("model_1_ptm").model) _model = AlphaFold(model_config("model_1_ptm"))
_model = _model.eval() _model = _model.eval()
if not os.path.exists(_param_path): if not os.path.exists(_param_path):
raise FileNotFoundError( raise FileNotFoundError(
......
...@@ -25,11 +25,17 @@ def random_template_feats(n_templ, n, batch_size=None): ...@@ -25,11 +25,17 @@ def random_template_feats(n_templ, n, batch_size=None):
"template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)), "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
"template_pseudo_beta": np.random.rand(*b, n_templ, n, 3), "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
"template_aatype": np.random.randint(0, 22, (*b, n_templ, n)), "template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
"template_all_atom_masks": np.random.randint( "template_all_atom_mask": np.random.randint(
0, 2, (*b, n_templ, n, 37) 0, 2, (*b, n_templ, n, 37)
), ),
"template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3) "template_all_atom_positions":
* 10, np.random.rand(*b, n_templ, n, 37, 3) * 10,
"template_torsion_angles_sin_cos":
np.random.rand(*b, n_templ, n, 7, 2),
"template_alt_torsion_angles_sin_cos":
np.random.rand(*b, n_templ, n, 7, 2),
"template_torsion_angles_mask":
np.random.rand(*b, n_templ, n, 7),
} }
batch = {k: v.astype(np.float32) for k, v in batch.items()} batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64) batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
......
...@@ -66,7 +66,6 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -66,7 +66,6 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout, msa_dropout,
pair_stack_dropout, pair_stack_dropout,
blocks_per_ckpt=None, blocks_per_ckpt=None,
chunk_size=4,
inf=inf, inf=inf,
eps=eps, eps=eps,
).eval() ).eval()
...@@ -79,7 +78,9 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -79,7 +78,9 @@ class TestEvoformerStack(unittest.TestCase):
shape_m_before = m.shape shape_m_before = m.shape
shape_z_before = z.shape shape_z_before = z.shape
m, z, s = es(m, z, msa_mask, pair_mask) m, z, s = es(
m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
)
self.assertTrue(m.shape == shape_m_before) self.assertTrue(m.shape == shape_m_before)
self.assertTrue(z.shape == shape_z_before) self.assertTrue(z.shape == shape_z_before)
...@@ -127,6 +128,7 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -127,6 +128,7 @@ class TestEvoformerStack(unittest.TestCase):
torch.as_tensor(activations["pair"]).cuda(), torch.as_tensor(activations["pair"]).cuda(),
torch.as_tensor(masks["msa"]).cuda(), torch.as_tensor(masks["msa"]).cuda(),
torch.as_tensor(masks["pair"]).cuda(), torch.as_tensor(masks["pair"]).cuda(),
chunk_size=4,
_mask_trans=False, _mask_trans=False,
) )
...@@ -171,7 +173,6 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -171,7 +173,6 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout, msa_dropout,
pair_stack_dropout, pair_stack_dropout,
blocks_per_ckpt=None, blocks_per_ckpt=None,
chunk_size=4,
inf=inf, inf=inf,
eps=eps, eps=eps,
).eval() ).eval()
...@@ -199,7 +200,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -199,7 +200,7 @@ class TestExtraMSAStack(unittest.TestCase):
shape_z_before = z.shape shape_z_before = z.shape
z = es(m, z, msa_mask, pair_mask) z = es(m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask)
self.assertTrue(z.shape == shape_z_before) self.assertTrue(z.shape == shape_z_before)
...@@ -212,12 +213,12 @@ class TestMSATransition(unittest.TestCase): ...@@ -212,12 +213,12 @@ class TestMSATransition(unittest.TestCase):
c_m = 7 c_m = 7
n = 11 n = 11
mt = MSATransition(c_m, n, chunk_size=4) mt = MSATransition(c_m, n)
m = torch.rand((batch_size, s_t, n_r, c_m)) m = torch.rand((batch_size, s_t, n_r, c_m))
shape_before = m.shape shape_before = m.shape
m = mt(m) m = mt(m, chunk_size=4)
shape_after = m.shape shape_after = m.shape
self.assertTrue(shape_before == shape_after) self.assertTrue(shape_before == shape_after)
......
...@@ -16,7 +16,7 @@ import torch ...@@ -16,7 +16,7 @@ import torch
import numpy as np import numpy as np
import unittest import unittest
import openfold.features.data_transforms as data_transforms import openfold.data.data_transforms as data_transforms
from openfold.np.residue_constants import ( from openfold.np.residue_constants import (
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
...@@ -102,10 +102,12 @@ class TestFeats(unittest.TestCase): ...@@ -102,10 +102,12 @@ class TestFeats(unittest.TestCase):
out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask) out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt) out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)
out_repro = feats.atom37_to_torsion_angles( out_repro = data_transforms.atom37_to_torsion_angles()(
torch.as_tensor(aatype).cuda(), {
torch.as_tensor(all_atom_pos).cuda(), "aatype": torch.as_tensor(aatype).cuda(),
torch.as_tensor(all_atom_mask).cuda(), "all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
"all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
},
) )
tasc = out_repro["torsion_angles_sin_cos"].cpu() tasc = out_repro["torsion_angles_sin_cos"].cpu()
atasc = out_repro["alt_torsion_angles_sin_cos"].cpu() atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
......
...@@ -27,7 +27,7 @@ class TestImportWeights(unittest.TestCase): ...@@ -27,7 +27,7 @@ class TestImportWeights(unittest.TestCase):
c = model_config("model_1_ptm") c = model_config("model_1_ptm")
c.globals.blocks_per_ckpt = None c.globals.blocks_per_ckpt = None
model = AlphaFold(c.model) model = AlphaFold(c)
import_jax_weights_( import_jax_weights_(
model, model,
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
import unittest import unittest
import ml_collections as mlc import ml_collections as mlc
from openfold.features import data_transforms from openfold.data import data_transforms
from openfold.utils.affine_utils import T, affine_vector_to_4x4 from openfold.utils.affine_utils import T, affine_vector_to_4x4
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.loss import ( from openfold.utils.loss import (
......
...@@ -18,7 +18,7 @@ import torch.nn as nn ...@@ -18,7 +18,7 @@ import torch.nn as nn
import numpy as np import numpy as np
import unittest import unittest
from openfold.config import model_config from openfold.config import model_config
from openfold.features.data_transforms import make_atom14_masks from openfold.data import data_transforms
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map from openfold.utils.tensor_utils import tree_map, tensor_tree_map
...@@ -42,22 +42,21 @@ class TestModel(unittest.TestCase): ...@@ -42,22 +42,21 @@ class TestModel(unittest.TestCase):
n_res = consts.n_res n_res = consts.n_res
n_extra_seq = consts.n_extra n_extra_seq = consts.n_extra
c = model_config("model_1").model c = model_config("model_1")
c.no_cycles = 2 c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.evoformer_stack.no_blocks = 4 # no need to go overboard here c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
c.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test # deepspeed for this test
model = AlphaFold(c) model = AlphaFold(c)
batch = {} batch = {}
tf = torch.randint(c.input_embedder.tf_dim - 1, size=(n_res,)) tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = nn.functional.one_hot( batch["target_feat"] = nn.functional.one_hot(
tf, c.input_embedder.tf_dim tf, c.model.input_embedder.tf_dim
).float() ).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1) batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res) batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.input_embedder.msa_dim)) batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res) t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()}) batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res) extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
...@@ -66,10 +65,11 @@ class TestModel(unittest.TestCase): ...@@ -66,10 +65,11 @@ class TestModel(unittest.TestCase):
low=0, high=2, size=(n_seq, n_res) low=0, high=2, size=(n_seq, n_res)
).float() ).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float() batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(make_atom14_masks(batch)) batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
add_recycling_dims = lambda t: ( add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.no_cycles) t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
) )
batch = tensor_tree_map(add_recycling_dims, batch) batch = tensor_tree_map(add_recycling_dims, batch)
...@@ -94,7 +94,7 @@ class TestModel(unittest.TestCase): ...@@ -94,7 +94,7 @@ class TestModel(unittest.TestCase):
with open("tests/test_data/sample_feats.pickle", "rb") as fp: with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp) batch = pickle.load(fp)
out_gt = jax.jit(f.apply)(params, jax.random.PRNGKey(42), batch) out_gt = f.apply(params, jax.random.PRNGKey(42), batch)
out_gt = out_gt["structure_module"]["final_atom_positions"] out_gt = out_gt["structure_module"]["final_atom_positions"]
# atom37_to_atom14 doesn't like batches # atom37_to_atom14 doesn't like batches
...@@ -103,13 +103,19 @@ class TestModel(unittest.TestCase): ...@@ -103,13 +103,19 @@ class TestModel(unittest.TestCase):
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch) out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long() batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long() batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long() batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch[ batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14" "residx_atom37_to_atom14"
].long() ].long()
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
)
# Move the recycling dimension to the end # Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0) move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
......
...@@ -41,13 +41,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ...@@ -41,13 +41,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
no_heads = 4 no_heads = 4
chunk_size = None chunk_size = None
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size) mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads)
m = torch.rand((batch_size, n_seq, n_res, c_m)) m = torch.rand((batch_size, n_seq, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z))
shape_before = m.shape shape_before = m.shape
m = mrapb(m, z) m = mrapb(m, z=z, chunk_size=chunk_size)
shape_after = m.shape shape_after = m.shape
self.assertTrue(shape_before == shape_after) self.assertTrue(shape_before == shape_after)
...@@ -91,8 +91,9 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ...@@ -91,8 +91,9 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model.evoformer.blocks[0] model.evoformer.blocks[0]
.msa_att_row( .msa_att_row(
torch.as_tensor(msa_act).cuda(), torch.as_tensor(msa_act).cuda(),
torch.as_tensor(pair_act).cuda(), z=torch.as_tensor(pair_act).cuda(),
torch.as_tensor(msa_mask).cuda(), chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(),
) )
.cpu() .cpu()
) )
...@@ -114,7 +115,7 @@ class TestMSAColumnAttention(unittest.TestCase): ...@@ -114,7 +115,7 @@ class TestMSAColumnAttention(unittest.TestCase):
x = torch.rand((batch_size, n_seq, n_res, c_m)) x = torch.rand((batch_size, n_seq, n_res, c_m))
shape_before = x.shape shape_before = x.shape
x = msaca(x) x = msaca(x, chunk_size=None)
shape_after = x.shape shape_after = x.shape
self.assertTrue(shape_before == shape_after) self.assertTrue(shape_before == shape_after)
...@@ -155,7 +156,8 @@ class TestMSAColumnAttention(unittest.TestCase): ...@@ -155,7 +156,8 @@ class TestMSAColumnAttention(unittest.TestCase):
model.evoformer.blocks[0] model.evoformer.blocks[0]
.msa_att_col( .msa_att_col(
torch.as_tensor(msa_act).cuda(), torch.as_tensor(msa_act).cuda(),
torch.as_tensor(msa_mask).cuda(), chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(),
) )
.cpu() .cpu()
) )
...@@ -177,7 +179,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): ...@@ -177,7 +179,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
x = torch.rand((batch_size, n_seq, n_res, c_m)) x = torch.rand((batch_size, n_seq, n_res, c_m))
shape_before = x.shape shape_before = x.shape
x = msagca(x) x = msagca(x, chunk_size=None)
shape_after = x.shape shape_after = x.shape
self.assertTrue(shape_before == shape_after) self.assertTrue(shape_before == shape_after)
...@@ -219,6 +221,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): ...@@ -219,6 +221,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model.extra_msa_stack.stack.blocks[0] model.extra_msa_stack.stack.blocks[0]
.msa_att_col( .msa_att_col(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(), torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
) )
.cpu() .cpu()
......
...@@ -38,11 +38,11 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -38,11 +38,11 @@ class TestOuterProductMean(unittest.TestCase):
mask = torch.randint( mask = torch.randint(
0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res) 0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res)
) )
m = opm(m, mask) m = opm(m, mask=mask, chunk_size=None)
self.assertTrue( self.assertTrue(
m.shape m.shape ==
== (consts.batch_size, consts.n_res, consts.n_res, consts.c_z) (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
) )
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
...@@ -84,6 +84,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -84,6 +84,7 @@ class TestOuterProductMean(unittest.TestCase):
model.evoformer.blocks[0] model.evoformer.blocks[0]
.outer_product_mean( .outer_product_mean(
torch.as_tensor(msa_act).cuda(), torch.as_tensor(msa_act).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(), mask=torch.as_tensor(msa_mask).cuda(),
) )
.cpu() .cpu()
......
...@@ -39,7 +39,7 @@ class TestPairTransition(unittest.TestCase): ...@@ -39,7 +39,7 @@ class TestPairTransition(unittest.TestCase):
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.randint(0, 2, size=(batch_size, n_res, n_res)) mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
shape_before = z.shape shape_before = z.shape
z = pt(z, mask) z = pt(z, mask=mask, chunk_size=None)
shape_after = z.shape shape_after = z.shape
self.assertTrue(shape_before == shape_after) self.assertTrue(shape_before == shape_after)
...@@ -79,6 +79,7 @@ class TestPairTransition(unittest.TestCase): ...@@ -79,6 +79,7 @@ class TestPairTransition(unittest.TestCase):
model.evoformer.blocks[0] model.evoformer.blocks[0]
.pair_transition( .pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4,
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
) )
.cpu() .cpu()
......
...@@ -16,7 +16,7 @@ import torch ...@@ -16,7 +16,7 @@ import torch
import numpy as np import numpy as np
import unittest import unittest
from openfold.features.data_transforms import make_atom14_masks_np from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import ( from openfold.np.residue_constants import (
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
...@@ -174,7 +174,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -174,7 +174,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile # The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to # We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data. # have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.01) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
class TestBackboneUpdate(unittest.TestCase): class TestBackboneUpdate(unittest.TestCase):
......
...@@ -42,13 +42,13 @@ class TestTemplatePointwiseAttention(unittest.TestCase): ...@@ -42,13 +42,13 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
inf = 1e7 inf = 1e7
tpa = TemplatePointwiseAttention( tpa = TemplatePointwiseAttention(
c_t, c_z, c, no_heads, chunk_size=4, inf=inf c_t, c_z, c, no_heads, inf=inf
) )
t = torch.rand((batch_size, n_seq, n_res, n_res, c_t)) t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z))
z_update = tpa(t, z) z_update = tpa(t, z, chunk_size=None)
self.assertTrue(z_update.shape == z.shape) self.assertTrue(z_update.shape == z.shape)
...@@ -79,7 +79,6 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -79,7 +79,6 @@ class TestTemplatePairStack(unittest.TestCase):
pair_transition_n=pt_inner_dim, pair_transition_n=pt_inner_dim,
dropout_rate=dropout, dropout_rate=dropout,
blocks_per_ckpt=None, blocks_per_ckpt=None,
chunk_size=chunk_size,
inf=inf, inf=inf,
eps=eps, eps=eps,
) )
...@@ -87,7 +86,7 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -87,7 +86,7 @@ class TestTemplatePairStack(unittest.TestCase):
t = torch.rand((batch_size, n_templ, n_res, n_res, c_t)) t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
mask = torch.randint(0, 2, (batch_size, n_templ, n_res, n_res)) mask = torch.randint(0, 2, (batch_size, n_templ, n_res, n_res))
shape_before = t.shape shape_before = t.shape
t = tpe(t, mask) t = tpe(t, mask, chunk_size=chunk_size)
shape_after = t.shape shape_after = t.shape
self.assertTrue(shape_before == shape_after) self.assertTrue(shape_before == shape_after)
...@@ -136,6 +135,7 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -136,6 +135,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_repro = model.template_pair_stack( out_repro = model.template_pair_stack(
torch.as_tensor(pair_act).cuda(), torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(), torch.as_tensor(pair_mask).cuda(),
chunk_size=None,
_mask_trans=False, _mask_trans=False,
).cpu() ).cpu()
...@@ -161,8 +161,8 @@ class Template(unittest.TestCase): ...@@ -161,8 +161,8 @@ class Template(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res) batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32) pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)] # Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights( params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding" "alphafold/alphafold_iteration/evoformer/template_embedding"
...@@ -182,6 +182,7 @@ class Template(unittest.TestCase): ...@@ -182,6 +182,7 @@ class Template(unittest.TestCase):
torch.as_tensor(pair_act).cuda(), torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(), torch.as_tensor(pair_mask).cuda(),
templ_dim=0, templ_dim=0,
chunk_size=None,
) )
out_repro = out_repro["template_pair_embedding"] out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment