Unverified commit bb3f51e5 authored by Christina Floristean, committed by GitHub

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
@@ -30,7 +30,7 @@ class TestLMA(unittest.TestCase):
q, kv, _, biases = random_attention_inputs(batch_size=consts.batch_size,
n_seq=consts.n_seq,
n=2**12,
n=2 ** 12,
no_heads=no_heads,
c_hidden=c_hidden)
......
@@ -18,21 +18,19 @@ import unittest
from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
restype_atom37_mask,
)
from openfold.model.structure_module import (
StructureModule,
StructureModuleTransition,
BackboneUpdate,
AngleResnet,
InvariantPointAttention,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
@@ -46,6 +44,20 @@ if compare_utils.alphafold_is_installed():
class TestStructureModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
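The same monomer/multimer aliasing block recurs in several test classes below (TestInvariantPointAttention, TestTemplatePairStack, Template). An equivalent shared helper could centralize it; this is an illustrative sketch only, not part of this PR:

import tests.compare_utils as compare_utils
from tests.config import consts

def bind_alphafold_aliases(cls):
    """Attach the monomer or multimer AlphaFold submodules to a test class."""
    if not compare_utils.alphafold_is_installed():
        return
    if consts.is_multimer:
        from alphafold.model import (all_atom_multimer as atom,
                                     folding_multimer as fold,
                                     modules_multimer as modules,
                                     geometry as rigid)
    else:
        from alphafold.model import (all_atom as atom, folding as fold,
                                     modules, r3 as rigid)
    cls.am_atom, cls.am_fold, cls.am_modules, cls.am_rigid = atom, fold, modules, rigid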
def test_structure_module_shape(self):
batch_size = consts.batch_size
n = consts.n_res
@@ -81,6 +93,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor,
ar_epsilon,
inf,
is_multimer=consts.is_multimer
)
s = torch.rand((batch_size, n, c_s))
@@ -89,7 +102,11 @@ class TestStructureModule(unittest.TestCase):
out = sm({"single": s, "pair": z}, f)
if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
else:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
)
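The two branches assert different frame layouts: the monomer stack packs each rigid frame as a 7-vector (a unit quaternion plus a 3-vector translation), while the multimer stack returns full 4x4 homogeneous matrices. For reference, a minimal sketch of the quaternion-to-rotation conversion underlying the 7-vector form (w-first quaternion convention assumed):

import torch

def quat_to_rot(q: torch.Tensor) -> torch.Tensor:
    """Convert [..., 4] unit quaternions (w, x, y, z) to [..., 3, 3] rotations."""
    w, x, y, z = q.unbind(-1)
    entries = [
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ]
    return torch.stack(entries, dim=-1).reshape(q.shape[:-1] + (3, 3))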
@@ -121,11 +138,14 @@ class TestStructureModule(unittest.TestCase):
c_global = config.model.global_config
def run_sm(representations, batch):
sm = alphafold.model.folding.StructureModule(c_sm, c_global)
sm = self.am_fold.StructureModule(c_sm, c_global)
representations = {
k: jax.lax.stop_gradient(v) for k, v in representations.items()
}
batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
if consts.is_multimer:
return sm(representations, batch, is_training=False, compute_loss=True)
return sm(representations, batch, is_training=False)
f = hk.transform(run_sm)
@@ -177,10 +197,24 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
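Several inline tolerance assertions in this diff are replaced by shared helpers in tests/compare_utils.py. A plausible shape for them (an illustrative sketch, not the repository's implementation):

import torch

def assert_mean_abs_diff_small(actual, expected, eps):
    # Mean absolute error: tolerant of a few volatile elements.
    diff = torch.mean(torch.abs(actual - expected)).item()
    assert diff < eps, f"mean abs diff {diff} >= {eps}"

def assert_max_abs_diff_small(actual, expected, eps):
    # Max absolute error: the stricter elementwise criterion.
    diff = torch.max(torch.abs(actual - expected)).item()
    assert diff < eps, f"max abs diff {diff} >= {eps}"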
class TestInvariantPointAttention(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self):
c_m = 13
c_z = 17
@@ -197,13 +231,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask = torch.ones((batch_size, n_res))
rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3))
if consts.is_multimer:
rotation = Rot3Array.from_array(rot_mats)
translation = Vec3Array.from_array(trans)
r = Rigid3Array(rotation, translation)
else:
rots = Rotation(rot_mats=rot_mats, quats=None)
r = Rigid(rots, trans)
ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp
c_m, c_z, c_hidden, no_heads, no_qp, no_vp, is_multimer=consts.is_multimer
)
shape_before = s.shape
@@ -215,16 +254,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def test_ipa_compare(self):
def run_ipa(act, static_feat_2d, mask, affine):
config = compare_utils.get_alphafold_config()
ipa = alphafold.model.folding.InvariantPointAttention(
ipa = self.am_fold.InvariantPointAttention(
config.model.heads.structure_module,
config.model.global_config,
)
if consts.is_multimer:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine,
rigid=affine
)
else:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine
)
return attn
f = hk.transform(run_ipa)
@@ -238,12 +287,19 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask = np.ones((n_res, 1))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = rigids
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
quats = self.am_rigid.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
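Both branches start from the same [*, 4, 4] homogeneous transforms and differ only in the wrapper type (Rigid3Array versus the quaternion-based Rigid). A minimal sketch of the decomposition both wrappers perform on such a tensor:

import torch

def split_affine_4x4(t: torch.Tensor):
    """Split [..., 4, 4] homogeneous transforms into rotations and translations."""
    rots = t[..., :3, :3]   # [..., 3, 3] rotation block
    trans = t[..., :3, 3]   # [..., 3] translation column
    return rots, trans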
ipa_params = compare_utils.fetch_alphafold_module_weights(
@@ -265,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class TestAngleResnet(unittest.TestCase):
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
@@ -19,7 +20,6 @@ from openfold.model.template import (
TemplatePointwiseAttention,
TemplatePairStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats
@@ -54,6 +54,20 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class TestTemplatePairStack(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self):
batch_size = consts.batch_size
c_t = consts.c_t
@@ -65,6 +79,8 @@ class TestTemplatePairStack(unittest.TestCase):
dropout = 0.25
n_templ = consts.n_templ
n_res = consts.n_res
tri_mul_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
blocks_per_ckpt = None
chunk_size = 4
inf = 1e7
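The fuse_projection_weights flag above keys off the config preset name; only the AlphaFold-Multimer v3 presets store the fused triangle-multiplication weights. Note that re.fullmatch already anchors the pattern, making the ^ and $ redundant. A quick illustration of the check:

import re

for name in ["model_1_multimer_v3", "model_1_multimer", "model_3"]:
    fused = re.fullmatch("model_[1-5]_multimer_v3", name) is not None
    print(name, fused)  # True only for model_1_multimer_v3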
@@ -78,6 +94,8 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads=no_heads,
pair_transition_n=pt_inner_dim,
dropout_rate=dropout,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
@@ -96,7 +114,35 @@ class TestTemplatePairStack(unittest.TestCase):
def run_template_pair_stack(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_ee = config.model.embeddings_and_evoformer
tps = alphafold.model.modules.TemplatePairStack(
if consts.is_multimer:
safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
template_iteration = self.am_modules.TemplateEmbeddingIteration(
c_ee.template.template_pair_stack,
config.model.global_config,
name='template_embedding_iteration')
def template_iteration_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
act = template_iteration(
act=act,
pair_mask=pair_mask,
is_training=False,
safe_key=safe_subkey)
return (act, safe_key)
if config.model.global_config.use_remat:
template_iteration_fn = hk.remat(template_iteration_fn)
safe_key, safe_subkey = safe_key.split()
template_stack = alphafold.model.layer_stack.layer_stack(
c_ee.template.template_pair_stack.num_block)(
template_iteration_fn)
act, _ = template_stack((pair_act, safe_subkey))
else:
tps = self.am_modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack",
@@ -115,6 +161,12 @@ class TestTemplatePairStack(unittest.TestCase):
low=0, high=2, size=(n_res, n_res)
).astype(np.float32)
if consts.is_multimer:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_embedding_iteration"
)
else:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
@@ -132,25 +184,43 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.template_pair_stack(
out_repro = model.template_embedder.template_pair_stack(
torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
chunk_size=None,
_mask_trans=False,
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class Template(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def test_template_embedding(pair, batch, mask_2d):
def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding(
te = self.am_modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template,
config.model.global_config,
)
if consts.is_multimer:
act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
else:
act = te(pair, batch, mask_2d, is_training=False)
return act
@@ -162,6 +232,14 @@ class Template(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
multichain_mask_2d = None
if consts.is_multimer:
asym_id = batch['asym_id'][0]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
).astype(np.float32)
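The multichain mask flags residue pairs that belong to the same chain, computed as a broadcasted equality over per-residue chain (asym) ids. A tiny worked example of the same comparison:

import numpy as np

asym_id = np.array([0, 0, 1])  # three residues, two chains
mask = (asym_id[..., None] == asym_id[..., None, :]).astype(np.float32)
# [[1. 1. 0.]
#  [1. 1. 0.]
#  [0. 0. 1.]]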
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
@@ -169,7 +247,7 @@ class Template(unittest.TestCase):
)
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
@@ -177,17 +255,36 @@ class Template(unittest.TestCase):
batch["target_feat"] = np.eye(22)[inds]
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
if consts.is_multimer:
out_repro_all = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
_mask_trans=False,
use_lma=False,
inplace_safe=False
)
else:
out_repro_all = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
mask_trans=False,
use_lma=False,
inplace_safe=False
)
out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
@@ -79,16 +79,16 @@ class TestTriangularAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_att_start
model.evoformer.blocks[0].pair_stack.tri_att_start
if starting
else model.evoformer.blocks[0].core.tri_att_end
else model.evoformer.blocks[0].pair_stack.tri_att_end
)
# To save memory, the full model transposes inputs outside of the
@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
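jax.numpy.DeviceArray was removed in recent JAX releases; jax.Array is the unified array type that replaces it as the tree_map leaf filter above. The call strips the stacked leading dimension from every parameter leaf, keeping block 0's weights. The same idea with stock JAX utilities (toy parameter tree):

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 5, 5)), "b": jnp.zeros((2, 5))}  # leading dim: 2 blocks
block0 = jax.tree_util.tree_map(lambda n: n[0], params)      # keep block 0 only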
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
......
@@ -13,6 +13,7 @@
# limitations under the License.
import torch
import re
import numpy as np
import unittest
from openfold.model.triangular_multiplicative_update import *
@@ -31,6 +32,12 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z = consts.c_z
c = 11
if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
tm = FusedTriangleMultiplicationOutgoing(
c_z,
c,
)
else:
tm = TriangleMultiplicationOutgoing(
c_z,
c,
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config.model.global_config,
name=name,
)
act = tri_mul(act=pair_act, mask=pair_mask)
act = tri_mul(pair_act, pair_mask)
return act
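In the fused variant selected above, the left and right edge projections of triangle multiplication are produced by one wider linear layer and then split, rather than by two separate projections; the v3 multimer weights are stored in that layout. A sketch of the idea (dimensions illustrative, not OpenFold's module):

import torch
import torch.nn as nn

c_z, c = 17, 11
z = torch.rand(8, 8, c_z)
fused_proj = nn.Linear(c_z, 2 * c)            # one matmul for both projections
left, right = fused_proj(z).chunk(2, dim=-1)  # [8, 8, c] each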
f = hk.transform(run_tri_mul)
@@ -78,24 +85,25 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=True, _inplace_chunk_size=4,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32)
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
import argparse
import logging
import os
import random
import sys
import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch
from openfold.config import model_config
from openfold.data.data_modules import (
OpenFoldDataModule,
DummyDataLoader,
)
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
@@ -39,6 +33,7 @@ from openfold.utils.validation_metrics import (
)
from openfold.utils.import_weights import (
import_jax_weights_,
import_openfold_weights_
)
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint,
@@ -53,7 +48,10 @@ class OpenFoldWrapper(pl.LightningModule):
super(OpenFoldWrapper, self).__init__()
self.config = config
self.model = AlphaFold(config)
self.is_multimer = self.config.globals.is_multimer
self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay
)
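The EMA wrapper keeps a shadow copy of the weights, updated as p_ema <- decay * p_ema + (1 - decay) * p, which validation swaps in below. A toy version of the update rule (illustrative, not openfold.utils.exponential_moving_average):

import torch

class ToyEma:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Shadow copy of all state entries (assumes floating-point tensors).
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)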
@@ -98,12 +96,19 @@ class OpenFoldWrapper(pl.LightningModule):
if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
ground_truth = batch.pop('gt_features', None)
# Run the model
outputs = self(batch)
# Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch)
if self.is_multimer:
batch = multi_chain_permutation_align(out=outputs,
features=batch,
ground_truth=ground_truth)
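For complexes containing identical chains, any permutation of those ground-truth chains is an equally valid target, so the batch is aligned to the permutation that best matches the prediction before the loss is computed. A brute-force illustration over chain centroids (not OpenFold's multi_chain_permutation_align, which restricts permutations to chains with identical sequence):

import itertools
import torch

def best_chain_permutation(pred_centroids, gt_centroids):
    """Pick the ground-truth chain order closest to the predicted centroids."""
    n = gt_centroids.shape[0]
    best_perm, best_cost = None, float("inf")
    for perm in itertools.permutations(range(n)):
        cost = torch.sum((pred_centroids - gt_centroids[list(perm)]) ** 2).item()
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return best_perm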
# Compute loss
loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
@@ -127,12 +132,20 @@ class OpenFoldWrapper(pl.LightningModule):
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
ground_truth = batch.pop('gt_features', None)
# Run the model
outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss and other metrics
batch["use_clamped_fape"] = 0.
if self.is_multimer:
batch = multi_chain_permutation_align(out=outputs,
features=batch,
ground_truth=ground_truth)
# Compute loss and other metrics
_, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
)
@@ -221,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
last_epoch=self.last_lr_step
)
return {
@@ -265,8 +279,8 @@ def main(args):
train=True,
low_prec=(str(args.precision) == "16")
)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
@@ -281,7 +295,7 @@
else:
sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params)
@@ -291,7 +305,13 @@
if(args.script_modules):
script_preset_(model_module)
#data_module = DummyDataLoader("new_batch.pickle")
if "multimer" in args.config_preset:
data_module = OpenFoldMultimerDataModule(
config=config.data,
batch_seed=args.seed,
**vars(args)
)
else:
data_module = OpenFoldDataModule(
config=config.data,
batch_seed=args.seed,
@@ -416,6 +436,10 @@ if __name__ == "__main__":
help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
)
parser.add_argument(
"--train_mmcif_data_cache_path", type=str, default=None,
help="Path to the json file which records all the information of mmcif structures used during training"
)
parser.add_argument(
"--use_single_seq_mode", type=str, default=False,
help="Use single sequence embeddings instead of MSAs."
@@ -436,6 +460,10 @@
"--val_alignment_dir", type=str, default=None,
help="Directory containing precomputed validation alignments"
)
parser.add_argument(
"--val_mmcif_data_cache_path", type=str, default=None,
help="path to the json file which records all the information of mmcif structures used during validation"
)
parser.add_argument(
"--kalign_binary_path", type=str, default='/usr/bin/kalign',
help="Path to the kalign binary"
......