Commit dba44612 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Resolve merge conflicts

parents 4bd1b4d5 576174f0
// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// Butterfly all-reduce over one warp: after the log2(32) xor-shuffle
// rounds below, every lane holds the maximum of all 32 lanes' inputs.
// All 32 lanes must participate (full 0xffffffff mask).
__inline__ __device__ float WarpAllReduceMax(float val) {
    int lane_delta = 1;
    while (lane_delta < 32) {
        float other = __shfl_xor_sync(0xffffffff, val, lane_delta);
        val = max(val, other);
        lane_delta *= 2;
    }
    return val;
}
// Butterfly all-reduce over one warp: every lane ends up with the sum of
// all 32 lanes' inputs. The xor offsets are kept in the same 1,2,4,8,16
// order as before so the floating-point summation order is unchanged.
// All 32 lanes must participate (full 0xffffffff mask).
__inline__ __device__ float WarpAllReduceSum(float val) {
    int lane_delta = 1;
    while (lane_delta < 32) {
        val += __shfl_xor_sync(0xffffffff, val, lane_delta);
        lane_delta *= 2;
    }
    return val;
}
// In-place, numerically stable softmax over the last dimension of a
// (rows x cols) matrix. One warp handles one row; each block holds four
// warps (launched with blockDim.x == 128), so block b covers rows
// 4*b .. 4*b+3. Each lane buffers its slice of the row in registers.
// Precondition: cols <= 1024, so cols_per_thread <= 32 fits in buf[32].
template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
) {
    int warp_id = threadIdx.x / 32;   // which of the block's 4 warps
    int lane_id = threadIdx.x % 32;   // lane within the warp
    // Cast BEFORE multiplying: blockIdx.x * 4 is evaluated in unsigned int
    // and silently wraps once blockIdx.x >= 2^30, corrupting the row index
    // for very large inputs.
    long long row_offset = (long long)blockIdx.x * 4 + warp_id;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
    // When 32 does not divide cols evenly, the lane at last_y gets a short
    // slice and lanes past it get nothing.
    int last_y = (cols / cols_per_thread);
    if (lane_id == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    }
    else if (lane_id > last_y) {
        cols_this_thread = 0;
    }
    float buf[32];
    // The whole warp shares row_offset, so this branch is warp-uniform and
    // the shuffle-based reductions below see all 32 lanes.
    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;  // in-place: output overwrites the input row
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }
        // Subtract the row max before exponentiating for numerical stability;
        // idle lanes contribute -inf, which is the identity for max.
        float thread_max = -CUDART_INF_F;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }
        float warp_max = WarpAllReduceMax(thread_max);
        float thread_sum = 0.f;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }
        float warp_sum = WarpAllReduceSum(thread_sum);
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
// Host-side launcher for attn_softmax_inplace_. Applies a softmax over the
// last dimension of `input`, in place, treating it as a (rows x cols)
// matrix. Supports float32 and bfloat16 tensors only.
void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
) {
    CHECK_INPUT(input);
    // Each lane buffers at most 32 elements (see the kernel), so cols is
    // capped at 32 * 32 = 1024. Reject larger inputs instead of silently
    // overflowing the register buffer.
    TORCH_CHECK(
        cols > 0 && cols <= 1024,
        "attn_softmax_inplace_forward_: cols must be in (0, 1024], got ", cols
    );
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    // One warp per row, four warps (128 threads) per block.
    long long grid = (rows + 3) / 4;
    TORCH_CHECK(
        grid <= 2147483647LL,
        "attn_softmax_inplace_forward_: too many rows (", rows, ")"
    );
    dim3 block(128);
    // Launch on PyTorch's current stream rather than the default stream so
    // the kernel is ordered correctly with surrounding torch ops.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block, 0, stream>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    }
    else if (input.dtype() == torch::kBFloat16) {
        attn_softmax_inplace_<at::BFloat16><<<grid, block, 0, stream>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
    else {
        // Previously any non-fp32 dtype (e.g. float16) was silently
        // reinterpreted as bfloat16; fail loudly instead.
        TORCH_CHECK(
            false,
            "attn_softmax_inplace_forward_: unsupported dtype ", input.dtype()
        );
    }
    // Surface launch-configuration errors immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA launch failed: ", cudaGetErrorString(err));
}
// Backward pass of the in-place softmax, fused with the dO @ V^T product.
// Recomputes dY = dO @ V^T on the fly (saving the memory of materializing
// it), then applies the softmax gradient dX = (dY - sum(Y * dY)) * Y,
// overwriting `output` (which holds Y) with dX.
// One warp per row of `output`; launched with blockDim.x == 128 (4 warps).
// Precondition: cols_output <= 1024 (32-element per-lane register buffers).
template<typename T>
__global__ void attn_softmax_inplace_grad_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
) {
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;
    // Cast BEFORE multiplying: blockIdx.x * 4 wraps in unsigned int for
    // blockIdx.x >= 2^30, corrupting the row index for very large inputs.
    long long row_offset = (long long)blockIdx.x * 4 + warp_id;
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;
    int rows_values = cols_output;
    // values are set to the beginning of the current
    // rows_values x cols_values leaf matrix
    long long value_row_offset = row_offset - row_offset % rows_values;
    int last_y = (cols_output / cols_per_thread);
    if (lane_id == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    }
    else if (lane_id > last_y) {
        cols_this_thread = 0;
    }
    float y_buf[32];
    float dy_buf[32];
    // Warp-uniform branch: all lanes of a warp share row_offset, so the
    // warp reduction below always sees all 32 lanes.
    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;
        // Compute this lane's chunk of dY = dO @ V^T on the fly.
        // Accumulate in float: summing directly in T (e.g. bfloat16, with
        // an 8-bit mantissa) loses precision over long reductions.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            float sum = 0.f;
            int value_row_idx = (lane_id * cols_per_thread) + i;
            #pragma unroll
            for (int j = 0; j < cols_values; j++) {
                int value_idx = value_row_idx * cols_values + j;
                sum += static_cast<float>(row_d_ov[j]) *
                       static_cast<float>(row_values[value_idx]);
            }
            dy_buf[i] = sum;
        }
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
        }
        // sum(Y * dY) over the row, via warp all-reduce.
        float thread_sum = 0.f;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }
        float warp_sum = WarpAllReduceSum(thread_sum);
        // dX = (dY - sum(Y * dY)) * Y, written over Y in place.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
                (dy_buf[i] - warp_sum) * y_buf[i]
            );
        }
    }
}
// Host-side launcher for attn_softmax_inplace_grad_. Overwrites `output`
// (the saved softmax result Y) with the softmax input gradient dX, given
// the output-value gradient `d_ov` and the attention `values`.
// All three tensors must share the same dtype (float32 or bfloat16).
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
) {
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);
    // Per-lane register buffers hold at most 32 elements each, capping the
    // softmax width at 32 * 32 = 1024.
    TORCH_CHECK(
        cols_output > 0 && cols_output <= 1024,
        "attn_softmax_inplace_backward_: cols_output must be in (0, 1024], got ",
        cols_output
    );
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
    // One warp per row, four warps (128 threads) per block.
    long long grid = (rows + 3) / 4;
    TORCH_CHECK(
        grid <= 2147483647LL,
        "attn_softmax_inplace_backward_: too many rows (", rows, ")"
    );
    dim3 block(128);
    // Launch on PyTorch's current stream so the kernel is ordered correctly
    // with surrounding torch ops, instead of the legacy default stream.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block, 0, stream>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    } else if (output.dtype() == torch::kBFloat16) {
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block, 0, stream>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    } else {
        // Previously any non-fp32 dtype was silently treated as bfloat16.
        TORCH_CHECK(
            false,
            "attn_softmax_inplace_backward_: unsupported dtype ", output.dtype()
        );
    }
    // Surface launch-configuration errors immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA launch failed: ", cudaGetErrorString(err));
}
......@@ -334,10 +334,12 @@ def supervised_chi_loss(
(true_chi_shifted - pred_angles) ** 2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# The ol' switcheroo
sq_chi_error = sq_chi_error.permute(
*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
)
sq_chi_loss = masked_mean(
chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
)
......@@ -1526,39 +1528,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return loss
def compute_drmsd(structure_1, structure_2, mask=None):
    """Computes the distance-based RMSD (dRMSD) between two structures.

    Args:
        structure_1: [*, N, 3] coordinate tensor
        structure_2: [*, N, 3] coordinate tensor
        mask: optional [*, N] mask. Masked positions are zeroed before the
            pairwise distance matrices are computed, and the normalization
            uses the number of unmasked positions.
    Returns:
        [*] tensor of dRMSD values
    """
    if mask is not None:
        structure_1 = structure_1 * mask[..., None]
        structure_2 = structure_2 * mask[..., None]

    # [*, N, N] pairwise Euclidean distance matrices
    d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
    d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
    d1 = torch.sqrt(torch.sum(d1 ** 2, dim=-1))
    d2 = torch.sqrt(torch.sum(d2 ** 2, dim=-1))

    drmsd = torch.sum((d1 - d2) ** 2, dim=(-1, -2))

    # Normalize by the number of (ordered) residue pairs. Guard against
    # n <= 1, which previously divided by zero and produced inf/nan.
    if mask is None:
        n = d1.shape[-1]
        denom = max(n * (n - 1), 1)
    else:
        n = torch.sum(mask, dim=-1)
        denom = torch.clamp(n * (n - 1), min=1)

    drmsd = torch.sqrt(drmsd / denom)

    return drmsd
def compute_drmsd_np(structure_1, structure_2, mask=None):
    """NumPy-friendly wrapper around compute_drmsd.

    Converts array-like inputs to torch tensors and delegates to
    compute_drmsd, returning its [*] tensor of dRMSD values.
    """
    mask_t = None if mask is None else torch.tensor(mask)
    return compute_drmsd(
        torch.tensor(structure_1),
        torch.tensor(structure_2),
        mask_t,
    )
class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
......@@ -1627,6 +1596,10 @@ class AlphaFoldLoss(nn.Module):
weight = self.config[loss_name].weight
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
#for k,v in batch.items():
# if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
# logging.warning(f"{k}: is nan")
#logging.warning(f"{loss_name}: {loss}")
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss
......
......@@ -17,7 +17,7 @@ class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
base_lr: float = 0.,
max_lr: float = 0.001,
warmup_no_steps: int = 1000,
start_decay_after_n_steps: int = 10000,
start_decay_after_n_steps: int = 50000,
decay_every_n_steps: int = 50000,
decay_factor: float = 0.95,
):
......
......@@ -42,7 +42,7 @@ def _superimpose_single(reference, coords):
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
def superimpose(reference, coords):
def superimpose(reference, coords, mask):
"""
Superimposes coordinates onto a reference by minimizing RMSD using SVD.
......@@ -51,18 +51,42 @@ def superimpose(reference, coords):
[*, N, 3] reference tensor
coords:
[*, N, 3] tensor
mask:
[*, N] tensor
Returns:
A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
"""
def select_unmasked_coords(coords, mask):
return torch.masked_select(
coords,
(mask > 0.)[..., None],
).reshape(-1, 3)
batch_dims = reference.shape[:-2]
flat_reference = reference.reshape((-1,) + reference.shape[-2:])
flat_coords = coords.reshape((-1,) + reference.shape[-2:])
flat_mask = mask.reshape((-1,) + mask.shape[-1:])
superimposed_list = []
rmsds = []
for r, c in zip(flat_reference, flat_coords):
superimposed, rmsd = _superimpose_single(r, c)
superimposed_list.append(superimposed)
rmsds.append(rmsd)
for r, c, m in zip(flat_reference, flat_coords, flat_mask):
r_unmasked_coords = select_unmasked_coords(r, m)
c_unmasked_coords = select_unmasked_coords(c, m)
superimposed, rmsd = _superimpose_single(
r_unmasked_coords,
c_unmasked_coords
)
# This is very inelegant, but idk how else to invert the masking
# procedure.
count = 0
superimposed_full_size = torch.zeros_like(r)
for i, unmasked in enumerate(m):
if(unmasked):
superimposed_full_size[i] = superimposed[count]
count += 1
superimposed_list.append(superimposed_full_size)
rmsds.append(rmsd)
superimposed_stacked = torch.stack(superimposed_list, dim=0)
rmsds_stacked = torch.stack(rmsds, dim=0)
......
......@@ -14,16 +14,47 @@
import torch
def drmsd(structure_1, structure_2, mask=None):
    """Computes the distance-based RMSD (dRMSD) between two structures.

    Args:
        structure_1: [*, N, 3] coordinate tensor
        structure_2: [*, N, 3] coordinate tensor
        mask: optional [*, N] mask. Pairs involving masked positions are
            excluded, and the normalization uses the per-batch count of
            unmasked positions.
    Returns:
        [*] tensor of dRMSD values
    """
    def prep_d(structure):
        # [*, N, N] pairwise Euclidean distance matrix
        d = structure[..., :, None, :] - structure[..., None, :, :]
        d = d ** 2
        d = torch.sqrt(torch.sum(d, dim=-1))
        return d

    d1 = prep_d(structure_1)
    d2 = prep_d(structure_2)

    drmsd = d1 - d2
    drmsd = drmsd ** 2
    if mask is not None:
        # Zero out distance pairs where either position is masked
        drmsd = drmsd * (mask[..., None] * mask[..., None, :])
    drmsd = torch.sum(drmsd, dim=(-1, -2))

    if mask is None:
        n = d1.shape[-1]
        # Python int: a plain conditional guard against n <= 1 is fine here
        scale = 1.0 / (n * (n - 1)) if n > 1 else 0.0
        drmsd = drmsd * scale
    else:
        # n is a [*] tensor here; the previous `if n > 1` raised
        # "Boolean value of Tensor ... is ambiguous" for batched masks.
        # Guard elementwise instead, with the same semantics.
        n = torch.sum(mask, dim=-1)
        denom = torch.clamp(n * (n - 1), min=1)
        drmsd = torch.where(n > 1, drmsd / denom, torch.zeros_like(drmsd))
    drmsd = torch.sqrt(drmsd)

    return drmsd
def drmsd_np(structure_1, structure_2, mask=None):
    """NumPy-friendly wrapper around drmsd.

    Converts array-like inputs to torch tensors and delegates to drmsd,
    returning its [*] tensor of dRMSD values.
    """
    mask_t = None if mask is None else torch.tensor(mask)
    return drmsd(
        torch.tensor(structure_1),
        torch.tensor(structure_2),
        mask_t,
    )
def gdt(p1, p2, mask, cutoffs):
    """Global distance test: the mean, over the given cutoffs, of the
    (masked, batch-averaged) fraction of positions of p1 that lie within
    each cutoff distance of the corresponding position of p2.

    Args:
        p1: [*, N, 3] coordinate tensor
        p2: [*, N, 3] coordinate tensor
        mask: [*, N] mask of valid positions
        cutoffs: iterable of distance cutoffs
    Returns:
        Scalar tensor with the averaged GDT score
    """
    n = torch.sum(mask, dim=-1)
    distances = torch.sqrt(
        torch.sum((p1.float() - p2.float()) ** 2, dim=-1)
    )
    per_cutoff = [
        torch.mean(torch.sum((distances <= c) * mask, dim=-1) / n)
        for c in cutoffs
    ]
    return sum(per_cutoff) / len(per_cutoff)
......
......@@ -234,7 +234,7 @@ def main(args):
# Relax the prediction.
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if("cuda" in args.model_device):
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
......@@ -249,6 +249,13 @@ def main(args):
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if(args.save_outputs):
output_dict_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
)
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......@@ -283,6 +290,10 @@ if __name__ == "__main__":
automatically according to the model name from
openfold/resources/params"""
)
parser.add_argument(
"--save_outputs", type=bool, default=False,
help="Whether to save all model outputs, including embeddings, etc."
)
parser.add_argument(
"--cpus", type=int, default=4,
help="""Number of CPUs with which to run alignment tools"""
......
......@@ -27,7 +27,7 @@ def parse_file(f, args):
local_data = {}
local_data["release_date"] = mmcif.header["release_date"]
chain_ids, seqs = mmcif.chain_to_seqres.items()
chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items()))
local_data["chain_ids"] = chain_ids
local_data["seqs"] = seqs
local_data["no_chains"] = len(chain_ids)
......
......@@ -12,8 +12,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import find_packages
from setuptools import setup
import os
from setuptools import setup, Extension, find_packages
import subprocess
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
version_dependent_macros = [
'-DVERSION_GE_1_1',
'-DVERSION_GE_1_3',
'-DVERSION_GE_1_5',
]
extra_cuda_flags = [
'-std=c++14',
'-maxrregcount=50',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'
]
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
extra_cuda_flags += cc_flag
setup(
name='openfold',
......@@ -25,7 +63,32 @@ setup(
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
include_package_data=True,
package_data={"": ["resources/stereo_chemical_props.txt"]},
package_data={
"openfold": ['utils/kernel/csrc/*'],
"": ["resources/stereo_chemical_props.txt"]
},
ext_modules=[CUDAExtension(
name="attn_core_inplace_cuda",
sources=[
"openfold/utils/kernel/csrc/softmax_cuda.cpp",
"openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
],
include_dirs=[
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'openfold/utils/kernel/csrc/'
)
],
extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros,
'nvcc': (
['-O3', '--use_fast_math'] +
version_dependent_macros +
extra_cuda_flags
),
}
)],
cmdclass={'build_ext': BuildExtension},
install_requires=[
'torch',
'deepspeed',
......
......@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
)
][1].transpose(-1, -2)
),
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
),
]
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import unittest
from openfold.model.primitives import _attention
from openfold.utils.kernel.attention_core import attention_core
from tests.config import consts
class TestAttentionCore(unittest.TestCase):
    """Compares the custom attention-core CUDA kernel (attention_core)
    against the reference _attention implementation, in both the forward
    and backward passes. Requires a CUDA device."""

    def test_attention_core_forward(self):
        n_res = consts.n_res
        no_heads = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c_hidden = consts.c_e
        dtype = torch.float32

        shape = [n_seq, no_heads, n_res, c_hidden]
        q, k, v = (torch.rand(shape, dtype=dtype).cuda() for _ in range(3))

        # Random 0/1 mask turned into an additive attention bias
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        mask_bias = (1e9 * mask - 1)[..., None, None, :].to(dtype)

        out_repro = attention_core(q, k, v, mask_bias, None)
        out_gt = _attention(q, k, v, [mask_bias])

        max_abs_diff = torch.max(torch.abs(out_repro - out_gt))
        self.assertTrue(max_abs_diff < consts.eps)

    def test_attention_core_backward(self):
        n_res = consts.n_res
        no_heads = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c_hidden = consts.c_e
        dtype = torch.float32

        shape = [n_seq, no_heads, n_res, c_hidden]

        def make_operand():
            return torch.rand(shape, dtype=dtype, requires_grad=True).cuda()

        q, k, v = make_operand(), make_operand(), make_operand()

        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        mask_bias = (1e9 * mask - 1)[..., None, None, :].to(dtype)

        def clone(t):
            # Clone so both implementations see identical inputs; retain
            # grads on the (non-leaf) clones so they can be compared.
            t = t.clone()
            if t.requires_grad:
                t.retain_grad()
            return t

        q_repro, k_repro, v_repro = clone(q), clone(k), clone(v)
        out_repro = attention_core(q_repro, k_repro, v_repro, mask_bias, None)
        torch.mean(out_repro).backward()

        q_gt, k_gt, v_gt = clone(q), clone(k), clone(v)
        out_gt = _attention(q_gt, k_gt, v_gt, [mask_bias])
        torch.mean(out_gt).backward()

        repro_tensors = [q_repro, k_repro, v_repro]
        gt_tensors = [q_gt, k_gt, v_gt]
        for t_repro, t_gt in zip(repro_tensors, gt_tensors):
            max_grad_diff = torch.max(torch.abs(t_repro.grad - t_gt.grad))
            self.assertTrue(max_grad_diff < consts.eps)


if __name__ == '__main__':
    unittest.main()
......@@ -42,6 +42,7 @@ from openfold.utils.loss import (
backbone_loss,
sidechain_loss,
tm_loss,
compute_plddt,
)
from openfold.utils.tensor_utils import (
tree_map,
......@@ -226,6 +227,21 @@ class TestLoss(unittest.TestCase):
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
)
@compare_utils.skip_unless_alphafold_installed()
def test_compute_plddt_compare(self):
n_res = consts.n_res
logits = np.random.rand(n_res, 50)
out_gt = alphafold.common.confidence.compute_plddt(logits)
out_gt = torch.tensor(out_gt)
logits_t = torch.tensor(logits)
out_repro = compute_plddt(logits_t)
self.assertTrue(
torch.max(torch.abs(out_gt - out_repro)) < consts.eps
)
def test_find_structural_violations(self):
n = consts.n_res
......@@ -655,7 +671,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_backbone_loss(self):
def test_backbone_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_sm = config.model.heads.structure_module
......
......@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0]
model.evoformer.blocks[0].core
.outer_product_mean(
torch.as_tensor(msa_act).cuda(),
chunk_size=4,
......
......@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0]
model.evoformer.blocks[0].core
.pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4,
......
......@@ -8,6 +8,7 @@ import os
#os.environ["NODE_RANK"]="0"
import random
import sys
import time
import numpy as np
......@@ -27,16 +28,18 @@ from openfold.data.data_modules import (
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
drmsd,
gdt_ts,
gdt_ha,
)
......@@ -72,12 +75,12 @@ class OpenFoldWrapper(pl.LightningModule):
on_step=train, on_epoch=(not train), logger=True,
)
if(train):
self.log(
f"train/loss_epoch",
loss_breakdown["loss"],
on_step=False, on_epoch=True, logger=True,
)
if(train):
self.log(
f"{phase}/{loss_name}_epoch",
indiv_loss,
on_step=False, on_epoch=True, logger=True,
)
with torch.no_grad():
other_metrics = self._compute_validation_metrics(
......@@ -116,16 +119,14 @@ class OpenFoldWrapper(pl.LightningModule):
def on_before_zero_grad(self, *args, **kwargs):
self.ema.update(self.model)
# def training_step_end(self, outputs):
# # Temporary measure to address DeepSpeed scheduler bug (PL issue 11694)
# if(self.trainer.global_step != self.last_lr_step):
# self.lr_schedulers().step()
# self.last_lr_step = self.trainer.global_step
def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
self.cached_weights = self.model.state_dict()
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
......@@ -171,20 +172,20 @@ class OpenFoldWrapper(pl.LightningModule):
eps=self.config.globals.eps,
per_residue=False,
)
metrics["lddt_ca"] = lddt_ca_score
drmsd_ca_score = compute_drmsd(
drmsd_ca_score = drmsd(
pred_coords_masked_ca,
gt_coords_masked_ca,
mask=all_atom_mask_ca,
mask=all_atom_mask_ca, # still required here to compute n
)
metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics):
superimposed_pred, _ = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca
superimposed_pred, alignment_rmsd = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
)
gdt_ts_score = gdt_ts(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
......@@ -193,6 +194,7 @@ class OpenFoldWrapper(pl.LightningModule):
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
metrics["alignment_rmsd"] = alignment_rmsd
metrics["gdt_ts"] = gdt_ts_score
metrics["gdt_ha"] = gdt_ha_score
......@@ -203,11 +205,23 @@ class OpenFoldWrapper(pl.LightningModule):
eps: float = 1e-5,
) -> torch.optim.Adam:
# Ignored as long as a DeepSpeed optimizer is configured
return torch.optim.Adam(
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=learning_rate,
eps=eps
)
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
"name": "AlphaFoldLRScheduler",
}
}
def on_load_checkpoint(self, checkpoint):
self.ema.load_state_dict(checkpoint["ema"])
......@@ -232,7 +246,7 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...")
# TorchScript components of the model
if(args.script_modules):
script_preset_(model_module)
......@@ -251,6 +265,8 @@ def main(args):
if(args.checkpoint_every_epoch):
mc = ModelCheckpoint(
every_n_epochs=1,
auto_insert_metric_name=False,
save_top_k=-1,
)
callbacks.append(mc)
......@@ -300,7 +316,12 @@ def main(args):
strategy = DDPPlugin(find_unused_parameters=False)
else:
strategy = None
if(args.wandb):
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}")
trainer = pl.Trainer.from_argparse_args(
args,
default_root_dir=args.output_dir,
......@@ -487,9 +508,15 @@ if __name__ == "__main__":
'used.'
)
)
parser.add_argument(
"--_distillation_structure_index_path", type=str, default=None,
)
parser.add_argument(
"--_alignment_index_path", type=str, default=None,
)
parser.add_argument(
"--_distillation_alignment_index_path", type=str, default=None,
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment