Unverified Commit 1feca912 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #2 from aqlaboratory/marta-sd/nan-loss

Fix NaN loss bug + add development env for Nvidians
parents 56f0a8a0 709b144c
lightning_logs
**/__pycache__
lib/conda
**/*egg-info
# Build an OpenFold development/test image on top of NVIDIA's PyTorch NGC container.
# The base image can be overridden at build time: --build-arg IMAGE_NAME=...
ARG IMAGE_NAME=nvcr.io/nvidia/pytorch:21.06-py3
FROM ${IMAGE_NAME} AS base
# Keep apt from blocking the build on interactive prompts.
ENV DEBIAN_FRONTEND=noninteractive
# NOTE(review): aria2 is presumably used by the download scripts below for
# parallel downloads — confirm against scripts/download_alphafold_params.sh.
RUN apt update && \
apt install -y aria2
# Copy the repository into the image and install the package plus its deps.
COPY . /workspace/openfold
WORKDIR /workspace/openfold
RUN pip install -r requirements_minimal.txt
RUN python setup.py install
# TODO add all dependencies needed for inference
# Fetch the stereochemical-properties data file (pinned to a specific
# OpenStructure commit) and expose it via a relative symlink at the path
# the test suite expects.
RUN wget -q -P openfold/resources \
https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt && \
mkdir -p tests/test_data/alphafold/common && \
ln -rs openfold/resources/stereo_chemical_props.txt tests/test_data/alphafold/common
# Download pretrained AlphaFold parameters and unpack the pickled test fixture.
RUN scripts/download_alphafold_params.sh openfold/resources
RUN gunzip tests/test_data/sample_feats.pickle.gz
...@@ -429,7 +429,7 @@ def _is_set(data: str) -> bool: ...@@ -429,7 +429,7 @@ def _is_set(data: str) -> bool:
def get_atom_coords( def get_atom_coords(
mmcif_object: MmcifObject, chain_id: str mmcif_object: MmcifObject, chain_id: str, zero_center: bool = True
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain # Locate the right chain
chains = list(mmcif_object.structure.get_chains()) chains = list(mmcif_object.structure.get_chains())
...@@ -474,4 +474,9 @@ def get_atom_coords( ...@@ -474,4 +474,9 @@ def get_atom_coords(
all_atom_positions[res_index] = pos all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask all_atom_mask[res_index] = mask
if zero_center:
binary_mask = all_atom_mask.astype(np.bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec
return all_atom_positions, all_atom_mask return all_atom_positions, all_atom_mask
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import operator
import time
import dllogger as logger
import numpy as np
import torch.cuda.profiler as profiler
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
from pytorch_lightning import Callback
def is_main_process():
    """Return True when this process is the node-local main rank.

    Distributed launchers (e.g. torch.distributed / PyTorch Lightning DDP)
    export LOCAL_RANK; when the variable is absent we are running a single
    process, which counts as the main one.
    """
    local_rank = os.getenv("LOCAL_RANK", "0")
    return int(local_rank) == 0
class PerformanceLoggingCallback(Callback):
    """Lightning callback that timestamps steps and logs performance stats.

    After ``warmup_steps`` steps it records a wall-clock timestamp per
    train/test batch; at the end of training (and of each epoch) it logs
    throughput (samples/s) and latency percentiles via dllogger. Optionally
    drives the CUDA profiler over the post-warmup region.

    Args:
        log_file: Path of the JSON file dllogger writes to.
        global_batch_size: Samples processed per step across all ranks;
            used as the numerator for throughput.
        warmup_steps: Number of initial steps excluded from timing.
        profile: If True, start ``torch.cuda.profiler`` once warmup
            completes and stop it when training ends.
    """

    def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False):
        # Fix: the original skipped Callback.__init__.
        super().__init__()
        logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)])
        self.warmup_steps = warmup_steps
        self.global_batch_size = global_batch_size
        self.step = 0
        self.profile = profile
        self.timestamps = []

    def do_step(self):
        """Advance the step counter and record a timestamp once past warmup."""
        self.step += 1
        if self.profile and self.step == self.warmup_steps:
            profiler.start()
        if self.step > self.warmup_steps:
            self.timestamps.append(time.time())

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self.do_step()

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self.do_step()

    def process_performance_stats(self, deltas):
        """Build a stats dict from an array of per-step durations (seconds).

        Returns throughput in samples/s plus mean and 90/95/99th-percentile
        latency in milliseconds, all rounded to 3 decimals.
        """
        def _round3(val):
            return round(val, 3)

        throughput_imgps = _round3(self.global_batch_size / np.mean(deltas))
        timestamps_ms = 1000 * deltas
        # Fix: dropped pointless f-prefixes from placeholder-free keys.
        stats = {
            "throughput": throughput_imgps,
            "latency_mean": _round3(timestamps_ms.mean()),
        }
        for level in [90, 95, 99]:
            stats[f"latency_{level}"] = _round3(np.percentile(timestamps_ms, level))
        return stats

    def _log(self):
        """Emit aggregated stats from the main process only."""
        if is_main_process():
            # Fix: with <2 timestamps (e.g. warmup >= total steps) the
            # original averaged an empty array -> RuntimeWarning + NaN
            # stats; skip logging instead.
            if len(self.timestamps) < 2:
                return
            # np.diff == pairwise t[i+1] - t[i], same as the old
            # map(operator.sub, ...) construction.
            deltas = np.diff(np.asarray(self.timestamps))
            stats = self.process_performance_stats(deltas)
            logger.log(step=(), data=stats)
            logger.flush()

    def on_train_end(self, trainer, pl_module):
        if self.profile:
            profiler.stop()
        self._log()

    def on_epoch_end(self, trainer, pl_module):
        self._log()
...@@ -1495,6 +1495,9 @@ class AlphaFoldLoss(nn.Module): ...@@ -1495,6 +1495,9 @@ class AlphaFoldLoss(nn.Module):
weight = self.config[loss_name].weight weight = self.config[loss_name].weight
if weight: if weight:
loss = loss_fn() loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{k} loss is NaN. Skipping example...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss cum_loss = cum_loss + weight * loss
return cum_loss return cum_loss
biopython==1.79
deepspeed==0.5.3
dm-tree==0.1.6
ml-collections==0.1.0
ninja==1.10.2
tensorboardX==1.8
triton==1.0.0
git+https://github.com/PyTorchLightning/pytorch-lightning.git@6d79184ec50d9f80448013cc2c01b5b744a70b44
...@@ -25,6 +25,9 @@ conda update -qy conda \ ...@@ -25,6 +25,9 @@ conda update -qy conda \
conda install -c bioconda aria2 conda install -c bioconda aria2
conda install -y -c bioconda hmmer==3.3.2 hhsuite==3.3.0 kalign2==2.04 conda install -y -c bioconda hmmer==3.3.2 hhsuite==3.3.0 kalign2==2.04
pip install nvidia-pyindex
pip install nvidia-dllogger
# Install DeepMind's OpenMM patch # Install DeepMind's OpenMM patch
OPENFOLD_DIR=$PWD OPENFOLD_DIR=$PWD
pushd lib/conda/envs/$ENV_NAME/lib/python3.7/site-packages/ \ pushd lib/conda/envs/$ENV_NAME/lib/python3.7/site-packages/ \
......
...@@ -34,6 +34,8 @@ from scripts.zero_to_fp32 import ( ...@@ -34,6 +34,8 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint get_fp32_state_dict_from_zero_checkpoint
) )
from openfold.utils.logger import PerformanceLoggingCallback
class OpenFoldWrapper(pl.LightningModule): class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config): def __init__(self, config):
...@@ -147,15 +149,25 @@ def main(args): ...@@ -147,15 +149,25 @@ def main(args):
strict=True, strict=True,
) )
callbacks.append(es) callbacks.append(es)
if args.log_performance:
global_batch_size = args.num_nodes * args.gpus
perf = PerformanceLoggingCallback(
log_file=os.path.join(args.output_dir, "performance_log.json"),
global_batch_size=global_batch_size,
)
callbacks.append(perf)
if(args.deepspeed_config_path is not None): if(args.deepspeed_config_path is not None):
strategy = DeepSpeedPlugin(config=args.deepspeed_config_path) strategy = DeepSpeedPlugin(config=args.deepspeed_config_path)
else: elif args.gpus > 1 or args.num_nodes > 1:
strategy = "ddp" strategy = "ddp"
else:
strategy = None
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
strategy=strategy, strategy=strategy,
callbacks=callbacks,
) )
if(args.resume_model_weights_only): if(args.resume_model_weights_only):
...@@ -174,6 +186,16 @@ def main(args): ...@@ -174,6 +186,16 @@ def main(args):
) )
def bool_type(bool_str: str):
bool_str_lower = bool_str.lower()
if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
return False
elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
return True
else:
raise ValueError(f'Cannot interpret {bool_str} as bool')
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
...@@ -234,7 +256,7 @@ if __name__ == "__main__": ...@@ -234,7 +256,7 @@ if __name__ == "__main__":
files.""" files."""
) )
parser.add_argument( parser.add_argument(
"--use_small_bfd", type=bool, default=False, "--use_small_bfd", type=bool_type, default=False,
help="Whether to use a reduced version of the BFD database" help="Whether to use a reduced version of the BFD database"
) )
parser.add_argument( parser.add_argument(
...@@ -246,12 +268,12 @@ if __name__ == "__main__": ...@@ -246,12 +268,12 @@ if __name__ == "__main__":
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled" help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
) )
parser.add_argument( parser.add_argument(
"--checkpoint_best_val", type=bool, default=True, "--checkpoint_best_val", type=bool_type, default=True,
help="""Whether to save the model parameters that perform best during help="""Whether to save the model parameters that perform best during
validation""" validation"""
) )
parser.add_argument( parser.add_argument(
"--early_stopping", type=bool, default=False, "--early_stopping", type=bool_type, default=False,
help="Whether to stop training when validation loss fails to decrease" help="Whether to stop training when validation loss fails to decrease"
) )
parser.add_argument( parser.add_argument(
...@@ -268,9 +290,13 @@ if __name__ == "__main__": ...@@ -268,9 +290,13 @@ if __name__ == "__main__":
help="Path to a model checkpoint from which to restore training state" help="Path to a model checkpoint from which to restore training state"
) )
parser.add_argument( parser.add_argument(
"--resume_model_weights_only", type=bool, default=False, "--resume_model_weights_only", type=bool_type, default=False,
help="Whether to load just model weights as opposed to training state" help="Whether to load just model weights as opposed to training state"
) )
parser.add_argument(
"--log_performance", action='store_true',
help="Measure performance"
)
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass # Disable the initial validation pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment