"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "d5716ae751f2ea24bda45fad810750b5c7c72b29"
Unverified commit e9db72d6, authored by LuGY, committed by GitHub

add metrics for training (#138)

* add metrics for training

* modify log loss format
parent 29f11deb
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Bio.SVDSuperimposer import SVDSuperimposer
import torch
def _superimpose_np(reference, coords):
    """
    Superimposes coordinates onto a reference by minimizing RMSD using SVD.

    Args:
        reference:
            [N, 3] reference array
        coords:
            [N, 3] array
    Returns:
        A tuple of [N, 3] superimposed coords and the final RMSD.
    """
    sup = SVDSuperimposer()
    sup.set(reference, coords)
    sup.run()
    return sup.get_transformed(), sup.get_rms()


def _superimpose_single(reference, coords):
    reference_np = reference.detach().cpu().numpy()
    coords_np = coords.detach().cpu().numpy()
    superimposed, rmsd = _superimpose_np(reference_np, coords_np)
    return coords.new_tensor(superimposed), coords.new_tensor(rmsd)


def superimpose(reference, coords, mask):
    """
    Superimposes coordinates onto a reference by minimizing RMSD using SVD.

    Args:
        reference:
            [*, N, 3] reference tensor
        coords:
            [*, N, 3] tensor
        mask:
            [*, N] tensor
    Returns:
        A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
    """
    def select_unmasked_coords(coords, mask):
        return torch.masked_select(
            coords,
            (mask > 0.)[..., None],
        ).reshape(-1, 3)

    batch_dims = reference.shape[:-2]
    flat_reference = reference.reshape((-1,) + reference.shape[-2:])
    flat_coords = coords.reshape((-1,) + reference.shape[-2:])
    flat_mask = mask.reshape((-1,) + mask.shape[-1:])
    superimposed_list = []
    rmsds = []
    for r, c, m in zip(flat_reference, flat_coords, flat_mask):
        r_unmasked_coords = select_unmasked_coords(r, m)
        c_unmasked_coords = select_unmasked_coords(c, m)
        superimposed, rmsd = _superimpose_single(
            r_unmasked_coords,
            c_unmasked_coords
        )
        # Scatter the superimposed (unmasked) coordinates back into a
        # full-size tensor; masked positions stay zero.
        superimposed_full_size = torch.zeros_like(r)
        superimposed_full_size[m > 0.] = superimposed
        superimposed_list.append(superimposed_full_size)
        rmsds.append(rmsd)

    superimposed_stacked = torch.stack(superimposed_list, dim=0)
    rmsds_stacked = torch.stack(rmsds, dim=0)

    superimposed_reshaped = superimposed_stacked.reshape(
        batch_dims + coords.shape[-2:]
    )
    rmsds_reshaped = rmsds_stacked.reshape(
        batch_dims
    )

    return superimposed_reshaped, rmsds_reshaped
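A minimal usage sketch for superimpose (hypothetical shapes and random data; assumes Biopython and PyTorch are installed, and that this module is importable as fastfold.utils.superimposition, per the import in the next file):

import torch
from fastfold.utils.superimposition import superimpose

# Hypothetical batch: 2 structures with 10 CA atoms each.
reference = torch.randn(2, 10, 3)
coords = torch.randn(2, 10, 3)
mask = torch.ones(2, 10)
mask[:, -2:] = 0.  # exclude the last two positions from the fit

aligned, rmsds = superimpose(reference, coords, mask)
print(aligned.shape, rmsds.shape)  # torch.Size([2, 10, 3]) torch.Size([2])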
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fastfold.model.hub.loss import lddt_ca
from fastfold.common import residue_constants
from fastfold.utils.superimposition import superimpose
def drmsd(structure_1, structure_2, mask=None):
    def prep_d(structure):
        d = structure[..., :, None, :] - structure[..., None, :, :]
        d = d ** 2
        d = torch.sqrt(torch.sum(d, dim=-1))
        return d

    d1 = prep_d(structure_1)
    d2 = prep_d(structure_2)

    drmsd = d1 - d2
    drmsd = drmsd ** 2
    if mask is not None:
        drmsd = drmsd * (mask[..., None] * mask[..., None, :])
    drmsd = torch.sum(drmsd, dim=(-1, -2))
    n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
    drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
    drmsd = torch.sqrt(drmsd)

    return drmsd
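# Example (hypothetical, not part of this module): for a single [N, 3]
# structure the mask sum is a 0-dim tensor, so the `n > 1` check in drmsd
# above stays unambiguous. dRMSD compares intra-structure distance matrices,
# so no prior superimposition is needed:
#
#     p_ref = torch.randn(10, 3)
#     p_pred = p_ref + 0.1 * torch.randn(10, 3)
#     drmsd(p_pred, p_ref, mask=torch.ones(10))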
def drmsd_np(structure_1, structure_2, mask=None):
    structure_1 = torch.tensor(structure_1)
    structure_2 = torch.tensor(structure_2)
    if mask is not None:
        mask = torch.tensor(mask)

    return drmsd(structure_1, structure_2, mask)
def gdt(p1, p2, mask, cutoffs):
    n = torch.sum(mask, dim=-1)

    p1 = p1.float()
    p2 = p2.float()
    distances = torch.sqrt(torch.sum((p1 - p2) ** 2, dim=-1))

    scores = []
    for c in cutoffs:
        score = torch.sum((distances <= c) * mask, dim=-1) / n
        score = torch.mean(score)
        scores.append(score)

    return sum(scores) / len(scores)


def gdt_ts(p1, p2, mask):
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    return gdt(p1, p2, mask, [0.5, 1., 2., 4.])
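# Example (hypothetical, not part of this module): GDT_TS averages the
# fraction of positions within 1/2/4/8 Angstroms of the reference after
# superimposition; GDT_HA uses the stricter 0.5/1/2/4 cutoffs:
#
#     p_ref = torch.randn(1, 10, 3)
#     p_pred = torch.randn(1, 10, 3)
#     mask = torch.ones(1, 10)
#     aligned, _ = superimpose(p_ref, p_pred, mask)
#     gdt_ts(aligned, p_ref, mask), gdt_ha(aligned, p_ref, mask)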
def compute_validation_metrics(
    batch,
    outputs,
    superimposition_metrics=False,
):
    metrics = {}

    gt_coords = batch["all_atom_positions"]
    pred_coords = outputs["final_atom_positions"]
    all_atom_mask = batch["all_atom_mask"]

    # This is super janky for superimposition. Fix later
    gt_coords_masked = gt_coords * all_atom_mask[..., None]
    pred_coords_masked = pred_coords * all_atom_mask[..., None]
    ca_pos = residue_constants.atom_order["CA"]
    gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
    pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
    all_atom_mask_ca = all_atom_mask[..., ca_pos]

    lddt_ca_score = lddt_ca(
        pred_coords,
        gt_coords,
        all_atom_mask,
        eps=1e-8,
        per_residue=False,
    )
    metrics["lddt_ca"] = lddt_ca_score

    drmsd_ca_score = drmsd(
        pred_coords_masked_ca,
        gt_coords_masked_ca,
        mask=all_atom_mask_ca,  # still required here to compute n
    )
    metrics["drmsd_ca"] = drmsd_ca_score

    if superimposition_metrics:
        superimposed_pred, alignment_rmsd = superimpose(
            gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
        )
        gdt_ts_score = gdt_ts(
            superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
        )
        gdt_ha_score = gdt_ha(
            superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
        )
        metrics["alignment_rmsd"] = alignment_rmsd
        metrics["gdt_ts"] = gdt_ts_score
        metrics["gdt_ha"] = gdt_ha_score

    return metrics
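A sketch of calling the metrics directly (hypothetical shapes and random data; the 37-slot atom axis follows the atom37 layout implied by residue_constants.atom_order, and superimposition_metrics=True additionally needs Biopython):

import torch
from fastfold.utils.validation_utils import compute_validation_metrics

n_res = 16
batch = {
    "all_atom_positions": torch.randn(1, n_res, 37, 3),
    "all_atom_mask": torch.ones(1, n_res, 37),
}
outputs = {"final_atom_positions": torch.randn(1, n_res, 37, 3)}

metrics = compute_validation_metrics(batch, outputs, superimposition_metrics=True)
print({k: float(v) for k, v in metrics.items()})  # lddt_ca, drmsd_ca, alignment_rmsd, gdt_ts, gdt_ha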
import os
import random
import torch
import numpy as np
@@ -6,19 +7,34 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import HybridAdam
from tqdm import tqdm
from fastfold.config import model_config
from fastfold.model.hub import AlphaFold, AlphaFoldLRScheduler, AlphaFoldLoss
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
from fastfold.utils.tensor_utils import tensor_tree_map
+from fastfold.utils.validation_utils import compute_validation_metrics
-import logging
-logging.disable(logging.WARNING)
+#import logging
+#logging.disable(logging.WARNING)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
+def log_loss(loss_breakdown, batch, outputs, train=True):
+    loss_info = ''
+    for loss_name, loss_value in loss_breakdown.items():
+        loss_info += f' {loss_name}={loss_value:.3f}'
+    with torch.no_grad():
+        other_metrics = compute_validation_metrics(
+            batch,
+            outputs,
+            superimposition_metrics=(not train)
+        )
+    for loss_name, loss_value in other_metrics.items():
+        loss_info += f' {loss_name}={loss_value:.3f}'
+    return loss_info
def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=False, action='store_true')
@@ -117,6 +133,26 @@ def main():
"--seed", type=int, default=42,
help="Random seed"
)
parser.add_argument(
"--max_epochs", type=int, default=10000,
help="The Max epochs of train"
)
parser.add_argument(
"--log_interval", type=int, default=1,
help="The interval steps of logging during training"
)
parser.add_argument(
"--log_path", type=str, default='train_log',
help="The path of log folder"
)
parser.add_argument(
"--save_ckpt_path", type=str, default=None,
help="The path where to save checkpoint, None means not save"
)
parser.add_argument(
"--save_ckpt_interval", type=int, default=1,
help="The interval epochs of save checkpoint"
)
args = parser.parse_args()
random.seed(args.seed)
@@ -127,6 +163,7 @@ def main():
    colossalai.launch_from_torch(config=dict(torch_ddp=dict(static_graph=True)))
    disable_existing_loggers()
    logger = get_dist_logger()
+    logger.log_to_file(args.log_path)
    config = model_config(args.config_preset, train=True)
    config.globals.inplace = False
@@ -179,38 +216,40 @@ def main():
        test_dataloader=test_dataloader,
    )

-    for epoch in range(200):
+    logger.info('Start training.', ranks=[0])
+    for epoch in range(args.max_epochs):
        engine.train()
        if gpc.get_global_rank() == 0:
            train_dataloader = tqdm(train_dataloader)
-        for batch in train_dataloader:
+        for i, batch in enumerate(train_dataloader):
            batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
            engine.zero_grad()
            output = engine(batch)
            batch = tensor_tree_map(lambda t: t[..., -1], batch)
            loss, loss_breakdown = engine.criterion(
                output, batch, _return_breakdown=True)
-            if gpc.get_global_rank() == 0:
-                train_dataloader.set_postfix(loss=float(loss))
+            if (i + 1) % args.log_interval == 0:
+                logger.info(f'Training, Epoch: {epoch}, Step: {i + 1}, '
+                            f'Global_Step: {epoch * args.train_epoch_len + i + 1},'
+                            f' Loss:{log_loss(loss_breakdown, batch, output)}', ranks=[0])
            engine.zero_grad()
            engine.backward(loss)
            engine.step()
        lr_scheduler.step()
        if test_dataloader is not None:
            engine.eval()
-            if gpc.get_global_rank() == 0:
-                train_dataloader = tqdm(train_dataloader)
-            for batch in test_dataloader:
+            for i, batch in enumerate(test_dataloader):
                batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
                with torch.no_grad():
                    output = engine(batch)
                    batch = tensor_tree_map(lambda t: t[..., -1], batch)
                    batch["use_clamped_fape"] = 0.
                    _, loss_breakdown = engine.criterion(
                        output, batch, _return_breakdown=True)
-                    if gpc.get_global_rank() == 0:
-                        train_dataloader.set_postfix(loss=float(loss))
+                    logger.info(f'Validation, Step: {i + 1},'
+                                f' Loss:{log_loss(loss_breakdown, batch, output, False)}', ranks=[0])

+        if (args.save_ckpt_path is not None) and ((epoch + 1) % args.save_ckpt_interval == 0):
+            torch.save(engine.model, os.path.join(args.save_ckpt_path, 'model.pth'))

if __name__ == "__main__":
    main()
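Two practical notes on the new checkpoint flags: the script does not create args.save_ckpt_path, so that directory must exist before training starts, and torch.save(engine.model, ...) pickles the entire module rather than a state_dict, so reloading requires the same FastFold code to be importable. A minimal reload sketch (hypothetical path, assuming --save_ckpt_path was set to 'checkpoints'):

import os
import torch

ckpt = os.path.join('checkpoints', 'model.pth')
model = torch.load(ckpt, map_location='cpu')  # unpickles the full module object
model.eval()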