"demo/tutorials/tutorials_1.ipynb" did not exist on "a64900f3382e8e32155c47b8c5597022480b20ac"
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import torch
import numpy as np
import unittest
from openfold.model.triangular_attention import TriangleAttention
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestTriangularAttention(unittest.TestCase):
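# Shape check: TriangleAttention should map a [batch, n_res, n_res, c_z] pair activation to an output of the same shape.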
def test_shape(self):
c_z = consts.c_z
c = 12
no_heads = 4
starting = True
tan = TriangleAttention(c_z, c, no_heads, starting)
batch_size = consts.batch_size
n_res = consts.n_res
x = torch.rand((batch_size, n_res, n_res, c_z))
shape_before = x.shape
x = tan(x, chunk_size=None)
shape_after = x.shape
self.assertTrue(shape_before == shape_after)
def _tri_att_compare(self, starting=False):
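# Compare OpenFold's TriangleAttention with the reference AlphaFold (Haiku/JAX) module, using pretrained weights from a single Evoformer block.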
name = (
"triangle_attention_"
+ ("starting" if starting else "ending")
+ "_node"
)
def run_tri_att(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_att = alphafold.model.modules.TriangleAttention(
c_e.triangle_attention_starting_node
if starting
else c_e.triangle_attention_ending_node,
config.model.global_config,
name=name,
)
act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
return act
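# hk.transform turns run_tri_att into a pure function whose parameters are supplied explicitly to apply().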
f = hk.transform(run_tri_att)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z) * 100
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
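# The fetched weights are stacked across Evoformer blocks; take the first block's slice of each parameter.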
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_att_start
if starting
else model.evoformer.blocks[0].core.tri_att_end
)
# To save memory, the full model transposes inputs outside of the
# triangle attention module. We adjust the module here.
module = copy.deepcopy(module)
module.starting = starting
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
chunk_size=None,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
self._tri_att_compare()
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_start_compare(self):
self._tri_att_compare(starting=True)
if __name__ == "__main__":
unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.model.triangular_multiplicative_update import *
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestTriangularMultiplicativeUpdate(unittest.TestCase):
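# Shape check: the triangle multiplicative update should preserve the pair representation's shape.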
def test_shape(self):
c_z = consts.c_z
c = 11
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
n_res = consts.n_res
batch_size = consts.batch_size
x = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
shape_before = x.shape
x = tm(x, mask)
shape_after = x.shape
self.assertTrue(shape_before == shape_after)
def _tri_mul_compare(self, incoming=False):
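# Compare OpenFold's triangle multiplicative update with the reference AlphaFold module, using pretrained weights.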
name = "triangle_multiplication_" + (
"incoming" if incoming else "outgoing"
)
def run_tri_mul(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_mul = alphafold.model.modules.TriangleMultiplication(
c_e.triangle_multiplication_incoming
if incoming
else c_e.triangle_multiplication_outgoing,
config.model.global_config,
name=name,
)
act = tri_mul(act=pair_act, mask=pair_mask)
return act
f = hk.transform(run_tri_mul)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=True, _inplace_chunk_size=4,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
self._tri_mul_compare()
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_in_compare(self):
self._tri_mul_compare(incoming=True)
def _tri_mul_inplace(self, incoming=False):
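# Verify that the memory-saving in-place (chunked) inference path matches the standard forward pass.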
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32)
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
)
out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=False,
).cpu()
# This has to come second because inference mode is in-place
out_inplace = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=True, _inplace_chunk_size=2,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps)
def test_tri_mul_out_inference(self):
self._tri_mul_inplace()
def test_tri_mul_in_inference(self):
self._tri_mul_inplace(incoming=True)
if __name__ == "__main__":
unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
import unittest
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
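# Rotation fixtures: +90 and -90 degree rotations about the x-axis.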
X_90_ROT = torch.tensor(
[
[1, 0, 0],
[0, 0, -1],
[0, 1, 0],
]
)
X_NEG_90_ROT = torch.tensor(
[
[1, 0, 0],
[0, 0, 1],
[0, -1, 0],
]
)
class TestUtils(unittest.TestCase):
def test_rigid_from_3_points_shape(self):
batch_size = 2
n_res = 5
x1 = torch.rand((batch_size, n_res, 3))
x2 = torch.rand((batch_size, n_res, 3))
x3 = torch.rand((batch_size, n_res, 3))
r = Rigid.from_3_points(x1, x2, x3)
rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(torch.all(tra == x2))
def test_rigid_from_4x4(self):
batch_size = 2
transf = [
[1, 0, 0, 1],
[0, 0, -1, 2],
[0, 1, 0, 3],
[0, 0, 0, 1],
]
transf = torch.tensor(transf)
true_rot = transf[:3, :3]
true_trans = transf[:3, 3]
transf = torch.stack([transf for _ in range(batch_size)], dim=0)
r = Rigid.from_tensor_4x4(transf)
rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))
def test_rigid_shape(self):
batch_size = 2
n = 5
transf = Rigid(
Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
)
self.assertTrue(transf.shape == (batch_size, n))
def test_rigid_cat(self):
batch_size = 2
n = 5
transf = Rigid(
Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
)
transf_cat = Rigid.cat([transf, transf], dim=0)
transf_rots = transf.get_rots().get_rot_mats()
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
self.assertTrue(transf_cat_rots.shape == (batch_size * 2, n, 3, 3))
transf_cat = Rigid.cat([transf, transf], dim=1)
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
self.assertTrue(
torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
)
def test_rigid_compose(self):
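# Composing the +90-degree x-rotation (translation [0, 1, 0]) with the -90-degree x-rotation (translation [0, 0, 1]) should give the identity rotation with zero translation.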
trans_1 = [0, 1, 0]
trans_2 = [0, 0, 1]
r = Rotation(rot_mats=X_90_ROT)
t = torch.tensor(trans_1)
t1 = Rigid(
Rotation(rot_mats=X_90_ROT),
torch.tensor(trans_1)
)
t2 = Rigid(
Rotation(rot_mats=X_NEG_90_ROT),
torch.tensor(trans_2)
)
t3 = t1.compose(t2)
self.assertTrue(
torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
)
self.assertTrue(
torch.all(t3.get_trans() == 0)
)
def test_rigid_apply(self):
rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
trans = torch.tensor([1, 1, 1])
trans = torch.stack([trans, trans], dim=0)
t = Rigid(Rotation(rot_mats=rots), trans)
x = torch.arange(30)
x = torch.stack([x, x], dim=0)
x = x.view(2, -1, 3) # [2, 10, 3]
pts = t[..., None].apply(x)
# All simple consequences of the two x-axis rotations
self.assertTrue(torch.all(pts[..., 0] == x[..., 0] + 1))
self.assertTrue(torch.all(pts[0, :, 1] == x[0, :, 2] * -1 + 1))
self.assertTrue(torch.all(pts[1, :, 1] == x[1, :, 2] + 1))
self.assertTrue(torch.all(pts[0, :, 2] == x[0, :, 1] + 1))
self.assertTrue(torch.all(pts[1, :, 2] == x[1, :, 1] * -1 + 1))
def test_quat_to_rot(self):
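# With the half-angle convention, the unit quaternion [cos(pi/4), sin(pi/4), 0, 0] encodes a 90-degree rotation about the x-axis.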
forty_five = math.pi / 4
quat = torch.tensor([math.cos(forty_five), math.sin(forty_five), 0, 0])
rot = quat_to_rot(quat)
eps = 1e-07
self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))
def test_rot_to_quat(self):
quat = rot_to_quat(X_90_ROT)
eps = 1e-07
ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
self.assertTrue(torch.all(torch.abs(quat - ans) < eps))
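# chunk_layer applies a module over flattened batch dimensions in chunks; the result must match the unchunked call exactly.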
def test_chunk_layer_tensor(self):
x = torch.rand(2, 4, 5, 15)
l = torch.nn.Linear(15, 30)
chunked = chunk_layer(l, {"input": x}, chunk_size=4, no_batch_dims=3)
unchunked = l(x)
self.assertTrue(torch.all(chunked == unchunked))
def test_chunk_layer_dict(self):
class LinearDictLayer(torch.nn.Linear):
def forward(self, input):
out = super().forward(input)
return {"out": out, "inner": {"out": out + 1}}
x = torch.rand(2, 4, 5, 15)
l = LinearDictLayer(15, 30)
chunked = chunk_layer(l, {"input": x}, chunk_size=4, no_batch_dims=3)
unchunked = l(x)
self.assertTrue(torch.all(chunked["out"] == unchunked["out"]))
self.assertTrue(
torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
)
def test_chunk_slice_dict(self):
x = torch.rand(3, 4, 3, 5)
x_flat = x.view(-1, 5)
prod = 1
for d in x.shape[:-1]:
prod = prod * d
for i in range(prod):
for j in range(i + 1, prod + 1):
chunked = _chunk_slice(x, i, j, len(x.shape[:-1]))
chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened))
@compare_utils.skip_unless_alphafold_installed()
def test_pre_compose_compare(self):
quat = np.random.rand(20, 4)
trans = [np.random.rand(20) for _ in range(3)]
quat_affine = alphafold.model.quat_affine.QuatAffine(
quat, translation=trans
)
update_vec = np.random.rand(20, 6)
new_gt = quat_affine.pre_compose(update_vec)
quat_t = torch.tensor(quat)
trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
rigid = Rigid(Rotation(quats=quat_t), trans_t)
new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))
new_gt_q = torch.tensor(np.array(new_gt.quaternion))
new_gt_t = torch.stack(
[torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
)
new_repro_q = new_repro.get_rots().get_quats()
new_repro_t = new_repro.get_trans()
self.assertTrue(
torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps
)
self.assertTrue(
torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps
)
import argparse
import os
import logging
import random
import numpy
import torch
from openfold.config import model_config
from openfold.data import feature_pipeline
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
from openfold.np import protein
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
relax_protein
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
from scripts.utils import add_data_args
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12)
):
# Gives a large speedup on Ampere-class GPUs
torch.set_float32_matmul_precision("high")
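# Inference only: disable gradient tracking globally.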
torch.set_grad_enabled(False)
def main(args):
os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.config_preset)
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(2**32)
numpy.random.seed(random_seed)
torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
with open(args.input_fasta) as fasta_file:
tags, sequences = parse_fasta(fasta_file.read())
if len(sequences) != 1:
raise ValueError("the threading script can only process a single sequence")
query_sequence = sequences[0]
query_tag = tags[0]
feature_dict = make_sequence_features_with_custom_template(
query_sequence,
args.input_mmcif,
args.template_id,
args.chain_id,
args.kalign_binary_path)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
processed_feature_dict = {
k: torch.as_tensor(v, device=args.model_device)
for k, v in processed_feature_dict.items()
}
model_generator = load_models_from_command_line(
config,
args.model_device,
args.openfold_checkpoint_path,
args.jax_param_path,
args.output_dir)
output_name = f'{query_tag}_{args.config_preset}'
for model, output_directory in model_generator:
out = run_model(model, processed_feature_dict, query_tag, args.output_dir)
# Toss out the recycling dimensions --- we don't need them anymore
processed_feature_dict = tensor_tree_map(
lambda x: numpy.array(x[..., -1].cpu()),
processed_feature_dict
)
out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)
unrelaxed_protein = prep_output(
out,
processed_feature_dict,
feature_dict,
feature_processor,
args.config_preset,
200, # This is the ri_multimer_gap. There are no multimer sequences here, so it doesn't matter what it's set to
args.subtract_plddt
)
unrelaxed_output_path = os.path.join(
output_directory, f'{output_name}_unrelaxed.pdb'
)
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...")
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_fasta", type=str, help="the path to a fasta file containing a single sequence to thread")
parser.add_argument("input_mmcif", type=str, help="the path to an mmcif file to thread the sequence on to")
parser.add_argument("--template_id", type=str, help="a PDB id or other identifier for the template")
parser.add_argument(
"--chain_id", type=str,
help="""The chain ID of the chain in the template to use"""
)
parser.add_argument(
"--model_device", type=str, default="cpu",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
)
parser.add_argument(
"--config_preset", type=str, default="model_1",
help="""Name of a model config preset defined in openfold/config.py"""
)
parser.add_argument(
"--jax_param_path", type=str, default=None,
help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
is also None, parameters are selected automatically according to
the model name from openfold/resources/params"""
)
parser.add_argument(
"--openfold_checkpoint_path", type=str, default=None,
help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
checkpoint directory or a .pt file"""
)
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
)
parser.add_argument(
"--subtract_plddt", action="store_true", default=False,
help="""Whether to output (100 - pLDDT) in the B-factor column instead
of the pLDDT itself"""
)
parser.add_argument(
"--data_random_seed", type=str, default=None
)
add_data_args(parser)
args = parser.parse_args()
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.config_preset + ".npz"
)
if(args.model_device == "cpu" and torch.cuda.is_available()):
logging.warning(
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)
main(args)
import argparse
import logging
import os
import random
import sys
import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch
from openfold.config import model_config
from openfold.data.data_modules import (
OpenFoldDataModule,
DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
drmsd,
gdt_ts,
gdt_ha,
)
from openfold.utils.import_weights import (
import_jax_weights_,
)
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from openfold.utils.logger import PerformanceLoggingCallback
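# LightningModule wrapper around AlphaFold: owns the loss, an exponential moving average (EMA) of the model weights, and the learning-rate schedule.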
class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config):
super(OpenFoldWrapper, self).__init__()
self.config = config
self.model = AlphaFold(config)
self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay
)
self.cached_weights = None
self.last_lr_step = -1
def forward(self, batch):
return self.model(batch)
def _log(self, loss_breakdown, batch, outputs, train=True):
phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items():
self.log(
f"{phase}/{loss_name}",
indiv_loss,
on_step=train, on_epoch=(not train), logger=True,
)
if(train):
self.log(
f"{phase}/{loss_name}_epoch",
indiv_loss,
on_step=False, on_epoch=True, logger=True,
)
with torch.no_grad():
other_metrics = self._compute_validation_metrics(
batch,
outputs,
superimposition_metrics=(not train)
)
for k,v in other_metrics.items():
self.log(
f"{phase}/{k}",
v,
on_step=False, on_epoch=True, logger=True
)
def training_step(self, batch, batch_idx):
if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
# Run the model
outputs = self(batch)
# Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss
loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
)
# Log it
self._log(loss_breakdown, batch, outputs)
return loss
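# Called before gradients are zeroed; update the EMA copy of the weights from the current model parameters.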
def on_before_zero_grad(self, *args, **kwargs):
self.ema.update(self.model)
def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss and other metrics
batch["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
)
self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
self.cached_weights = None
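# lDDT-Ca and dRMSD are always computed; superimposition-based metrics (GDT_TS, GDT_HA, alignment RMSD) only during validation.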
def _compute_validation_metrics(self,
batch,
outputs,
superimposition_metrics=False
):
metrics = {}
gt_coords = batch["all_atom_positions"]
pred_coords = outputs["final_atom_positions"]
all_atom_mask = batch["all_atom_mask"]
# This is super janky for superimposition. Fix later
gt_coords_masked = gt_coords * all_atom_mask[..., None]
pred_coords_masked = pred_coords * all_atom_mask[..., None]
ca_pos = residue_constants.atom_order["CA"]
gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
all_atom_mask_ca = all_atom_mask[..., ca_pos]
lddt_ca_score = lddt_ca(
pred_coords,
gt_coords,
all_atom_mask,
eps=self.config.globals.eps,
per_residue=False,
)
metrics["lddt_ca"] = lddt_ca_score
drmsd_ca_score = drmsd(
pred_coords_masked_ca,
gt_coords_masked_ca,
mask=all_atom_mask_ca, # still required here to compute n
)
metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics):
superimposed_pred, alignment_rmsd = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
)
gdt_ts_score = gdt_ts(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
gdt_ha_score = gdt_ha(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
metrics["alignment_rmsd"] = alignment_rmsd
metrics["gdt_ts"] = gdt_ts_score
metrics["gdt_ha"] = gdt_ha_score
return metrics
def configure_optimizers(self,
learning_rate: float = 1e-3,
eps: float = 1e-5,
) -> torch.optim.Adam:
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# Ignored as long as a DeepSpeed optimizer is configured
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=learning_rate,
eps=eps
)
if self.last_lr_step != -1:
for group in optimizer.param_groups:
if 'initial_lr' not in group:
group['initial_lr'] = learning_rate
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
"name": "AlphaFoldLRScheduler",
}
}
def on_load_checkpoint(self, checkpoint):
ema = checkpoint["ema"]
if(not self.model.template_config.enabled):
ema["params"] = {k: v for k, v in ema["params"].items() if "template" not in k}
self.ema.load_state_dict(ema)
def on_save_checkpoint(self, checkpoint):
checkpoint["ema"] = self.ema.state_dict()
def resume_last_lr_step(self, lr_step):
self.last_lr_step = lr_step
def load_from_jax(self, jax_path):
model_basename = os.path.splitext(
os.path.basename(
os.path.normpath(jax_path)
)
)[0]
model_version = "_".join(model_basename.split("_")[1:])
import_jax_weights_(
self.model, jax_path, version=model_version
)
def main(args):
if(args.seed is not None):
seed_everything(args.seed)
config = model_config(
args.config_preset,
train=True,
low_prec=(str(args.precision) == "16")
)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
# TorchScript components of the model
if(args.script_modules):
script_preset_(model_module)
#data_module = DummyDataLoader("new_batch.pickle")
data_module = OpenFoldDataModule(
config=config.data,
batch_seed=args.seed,
**vars(args)
)
data_module.prepare_data()
data_module.setup()
callbacks = []
if(args.checkpoint_every_epoch):
mc = ModelCheckpoint(
every_n_epochs=1,
auto_insert_metric_name=False,
save_top_k=-1,
)
callbacks.append(mc)
if(args.early_stopping):
es = EarlyStoppingVerbose(
monitor="val/lddt_ca",
min_delta=args.min_delta,
patience=args.patience,
verbose=False,
mode="max",
check_finite=True,
strict=True,
)
callbacks.append(es)
if(args.log_performance):
global_batch_size = args.num_nodes * args.gpus
perf = PerformanceLoggingCallback(
log_file=os.path.join(args.output_dir, "performance_log.json"),
global_batch_size=global_batch_size,
)
callbacks.append(perf)
if(args.log_lr):
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks.append(lr_monitor)
loggers = []
if(args.wandb):
wdb_logger = WandbLogger(
name=args.experiment_name,
save_dir=args.output_dir,
id=args.wandb_id,
project=args.wandb_project,
**{"entity": args.wandb_entity}
)
loggers.append(wdb_logger)
if(args.deepspeed_config_path is not None):
strategy = DeepSpeedPlugin(
config=args.deepspeed_config_path,
)
if(args.wandb):
wdb_logger.experiment.save(args.deepspeed_config_path)
wdb_logger.experiment.save("openfold/config.py")
elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
strategy = DDPPlugin(find_unused_parameters=False)
else:
strategy = None
if(args.wandb):
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}")
trainer = pl.Trainer.from_argparse_args(
args,
default_root_dir=args.output_dir,
strategy=strategy,
callbacks=callbacks,
logger=loggers,
)
if(args.resume_model_weights_only):
ckpt_path = None
else:
ckpt_path = args.resume_from_ckpt
trainer.fit(
model_module,
datamodule=data_module,
ckpt_path=ckpt_path,
)
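# argparse helper: map common textual booleans ("true"/"false", "yes"/"no", "1"/"0", ...) to Python bools.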
def bool_type(bool_str: str):
bool_str_lower = bool_str.lower()
if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
return False
elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
return True
else:
raise ValueError(f'Cannot interpret {bool_str} as bool')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"train_data_dir", type=str,
help="Directory containing training mmCIF files"
)
parser.add_argument(
"train_alignment_dir", type=str,
help="Directory containing precomputed training alignments"
)
parser.add_argument(
"template_mmcif_dir", type=str,
help="Directory containing mmCIF files to search for templates"
)
parser.add_argument(
"output_dir", type=str,
help='''Directory in which to output checkpoints, logs, etc. Ignored
if not on rank 0'''
)
parser.add_argument(
"max_template_date", type=str,
help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
)
parser.add_argument(
"--distillation_data_dir", type=str, default=None,
help="Directory containing training PDB files"
)
parser.add_argument(
"--distillation_alignment_dir", type=str, default=None,
help="Directory containing precomputed distillation alignments"
)
parser.add_argument(
"--val_data_dir", type=str, default=None,
help="Directory containing validation mmCIF files"
)
parser.add_argument(
"--val_alignment_dir", type=str, default=None,
help="Directory containing precomputed validation alignments"
)
parser.add_argument(
"--kalign_binary_path", type=str, default='/usr/bin/kalign',
help="Path to the kalign binary"
)
parser.add_argument(
"--train_filter_path", type=str, default=None,
help='''Optional path to a text file containing names of training
examples to include, one per line. Used to filter the training
set'''
)
parser.add_argument(
"--distillation_filter_path", type=str, default=None,
help="""See --train_filter_path"""
)
parser.add_argument(
"--obsolete_pdbs_file_path", type=str, default=None,
help="""Path to obsolete.dat file containing list of obsolete PDBs and
their replacements."""
)
parser.add_argument(
"--template_release_dates_cache_path", type=str, default=None,
help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
)
parser.add_argument(
"--use_small_bfd", type=bool_type, default=False,
help="Whether to use a reduced version of the BFD database"
)
parser.add_argument(
"--seed", type=int, default=None,
help="Random seed"
)
parser.add_argument(
"--deepspeed_config_path", type=str, default=None,
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser.add_argument(
"--checkpoint_every_epoch", action="store_true", default=False,
help="""Whether to checkpoint at the end of every training epoch"""
)
parser.add_argument(
"--early_stopping", type=bool_type, default=False,
help="Whether to stop training when validation loss fails to decrease"
)
parser.add_argument(
"--min_delta", type=float, default=0,
help="""The smallest decrease in validation loss that counts as an
improvement for the purposes of early stopping"""
)
parser.add_argument(
"--patience", type=int, default=3,
help="Early stopping patience"
)
parser.add_argument(
"--resume_from_ckpt", type=str, default=None,
help="Path to a model checkpoint from which to restore training state"
)
parser.add_argument(
"--resume_model_weights_only", type=bool_type, default=False,
help="Whether to load just model weights as opposed to training state"
)
parser.add_argument(
"--resume_from_jax_params", type=str, default=None,
help="""Path to an .npz JAX parameter file with which to initialize the model"""
)
parser.add_argument(
"--log_performance", type=bool_type, default=False,
help="Measure performance"
)
parser.add_argument(
"--wandb", action="store_true", default=False,
help="Whether to log metrics to Weights & Biases"
)
parser.add_argument(
"--experiment_name", type=str, default=None,
help="Name of the current experiment. Used for wandb logging"
)
parser.add_argument(
"--wandb_id", type=str, default=None,
help="ID of a previous run to be resumed"
)
parser.add_argument(
"--wandb_project", type=str, default=None,
help="Name of the wandb project to which this run will belong"
)
parser.add_argument(
"--wandb_entity", type=str, default=None,
help="wandb username or team name to which runs are attributed"
)
parser.add_argument(
"--script_modules", type=bool_type, default=False,
help="Whether to TorchScript eligible components of the model"
)
parser.add_argument(
"--train_chain_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--distillation_chain_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--train_epoch_len", type=int, default=10000,
help=(
"The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."
)
)
parser.add_argument(
"--log_lr", action="store_true", default=False,
help="Whether to log the actual learning rate"
)
parser.add_argument(
"--config_preset", type=str, default="initial_training",
help=(
'Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'
)
)
parser.add_argument(
"--_distillation_structure_index_path", type=str, default=None,
)
parser.add_argument(
"--alignment_index_path", type=str, default=None,
help="Training alignment index. See the README for instructions."
)
parser.add_argument(
"--distillation_alignment_index_path", type=str, default=None,
help="Distillation alignment index. See the README for instructions."
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
parser.set_defaults(
num_sanity_val_steps=0,
)
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments(
parser,
[
"--accelerator",
"--resume_from_checkpoint",
"--reload_dataloaders_every_epoch",
"--reload_dataloaders_every_n_epochs",
]
)
args = parser.parse_args()
if(args.seed is None and
((args.gpus is not None and args.gpus > 1) or
(args.num_nodes is not None and args.num_nodes > 1))):
raise ValueError("For distributed training, --seed must be specified")
if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained JAX weights and resuming from a checkpoint")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
main(args)