#ifndef VALIDATE_OPENMM_FORCES_H_ #define VALIDATE_OPENMM_FORCES_H_ /* -------------------------------------------------------------------------- * * OpenMM * * -------------------------------------------------------------------------- * * This is part of the OpenMM molecular simulation toolkit originating from * * Simbios, the NIH National Center for Physics-Based Simulation of * * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * * Portions copyright (c) 2008 Stanford University and the Authors. * * Authors: Mark Friedrichs * * Contributors: * * * * This program is free software: you can redistribute it and/or modify * * it under the terms of the GNU Lesser General Public License as published * * by the Free Software Foundation, either version 3 of the License, or * * (at your option) any later version. * * * * This program is distributed in the hope that it will be useful, * * but WITHOUT ANY WARRANTY; without even the implied warranty of * * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * * GNU Lesser General Public License for more details. * * * * You should have received a copy of the GNU Lesser General Public License * * along with this program. If not, see . * * -------------------------------------------------------------------------- */ #include "ValidateOpenMM.h" namespace OpenMM { typedef std::map< int, int > MapIntInt; typedef MapIntInt::iterator MapIntIntI; typedef MapIntInt::const_iterator MapIntIntCI; // Helper class for ValidateOpenMMForces class used to store force results and facitilitate comparisons of // resulting forces class ForceValidationResult { public: ForceValidationResult(const Context& context1, const Context& context2, StringUIntMap& forceNamesMap); ~ForceValidationResult(); /** * Get potential energy at specified platform index (0 || 1) * * @return potential energy for spercifed platform * * @throws OpenMMException if energyIndex is not 0 or 1 */ double getPotentialEnergy(int energyIndex) const; /** * Get array of forces at specified platform index (0 || 1) * * @return array of force norms * * @throws OpenMMException if forceIndex is not 0 or 1 */ std::vector getForceNorms(int forceIndex) const; /** * Get array of forces at platform index (0 || 1) * * @return force array * * @throws OpenMMException if forceIndex is not 0 or 1 */ std::vector getForces(int forceIndex) const; /** * Get maximum delta in force norm * * @param maxIndex return atom index of entry with maximum delta norm (optional) * * @return max delta in norm of forces */ double getMaxDeltaForceNorm(int* maxIndex = NULL) const; /** * Get maximum relative delta in force norm * * @param maxIndex return atom index of entry w/ maximum relative delta norm (optional) * * @return max relative delta in norm of forces */ double getMaxRelativeDeltaForceNorm(int* maxIndex = NULL) const; /** * Get maximum dot product between forces * * @param maxIndex return atom index of entry w/ maximum dot product between forces (optional) * * @return max dot product between forces */ double getMaxDotProduct(int* maxIndex = NULL) const; /** * Get name of force associated w/ computed results * * @return force name(s); if more than one force active in computation, * then names are concatenated and separated by '::' (e.g., 'NB_FORCE::GBSA_OBC_FORCE') */ std::string getForceName() const; /** * Get platform name * * @param index index of platform (0 or 1) * * @return platform name * * @throws OpenMMException if index is not 0 or 1 */ std::string getPlatformName(int index) const; /** * Register index of two entries that differ by a specified tolerance * * @param index inconsistent index * */ void registerInconsistentForceIndex(int index, int value = 1); /** * Clear list of entries that differ by a specified tolerance * */ void clearInconsistentForceIndexList(); /** * Get list of entries that differ by a specified tolerance * */ void getInconsistentForceIndexList(std::vector& inconsistentIndices) const; /** * Get number of entries in inconsistent index list * */ int getNumberOfInconsistentForceEntries() const; /** * Return true if nans were detected * * @return true if nans were detected */ int nansDetected() const; /** * Determine if force norms are valid * * @param tolerance tolerance */ void compareForceNorms(double tolerance); /** * Determine if forces are valid * * @param tolerance tolerance */ void compareForces(double tolerance); private: // computed potential energies and forces fror two platforms double _potentialEnergies[2]; std::vector _forces[2]; // platform and force names std::string _platforms[2]; std::vector _forceNames; // force norms and stat entries std::vector _norms[2]; std::vector _normStatVectors[2]; // map of indicies w/ inconsistent force entries std::map _inconsistentForceIndicies; // if set, then nans detected int _nansDetected; /** * Calculate norms of vectors * */ void _calculateNorms(); /** * Calculate norms of specified vector * */ void _calculateNormOfForceVector(int forceIndex); // stat indices static const int STAT_AVG = 0; static const int STAT_STD = 1; static const int STAT_MIN = 2; static const int STAT_ID1 = 3; static const int STAT_MAX = 4; static const int STAT_ID2 = 5; static const int STAT_CNT = 6; static const int STAT_END = 7; /** * Find vector stats * */ void _findStatsForDouble(const std::vector& array, std::vector& statVector) const; }; // Class used to compare forces/potential energies on two platforms class ValidateOpenMMForces : public ValidateOpenMM { public: OPENMM_VALIDATE_EXPORT ValidateOpenMMForces(); OPENMM_VALIDATE_EXPORT ~ValidateOpenMMForces(); /** * Validate force/energy by comparing the results between the forces/energies computed on user-provided context platform * with Reference platform * * @param context context reference * @param summaryString output summary string of results of comparison (optional) * * @return number of inconsistent entries */ int OPENMM_VALIDATE_EXPORT compareWithReferencePlatform(Context& context, std::string* summaryString = NULL); /** * Validate force/energy by comparing the results between the forces/energies computed on two different platforms * * @param context context reference * @param compareForces indices of force to be tested * @param platform1 first platform to compute forces * @param platform2 second platform to compute forces * * @return ForceValidationResult reference containing results of force/energy computations * on the two input platforms */ ForceValidationResult* compareForce(Context& context, std::vector& compareForces, Platform& platform1, Platform& platform2) const; /** * Compare individual forces by comparing calculations across two platforms (platform associated w/ input context and * comparisonPlatform) * * @param context context reference * @param platform comparsion platform reference * @param forceValidationResults output vector of ForceValidationResult ptrs (user is responsible for deleting * individual ForceValidationResult objects) */ void compareOpenMMForces(Context& context, Platform& comparisonPlatform, std::vector& forceValidationResults) const; /** * Determine if results are consistent * * @param forceValidationResults vector of ForceValidationResult ptrs to check if forces are consistent */ void checkForInconsistentForceEntries(std::vector& forceValidationResults) const; /** * Get total number of force entries that are inconsistent * * @param forceValidationResults vector of ForceValidationResult ptrs to check if forces are consistent */ int getTotalNumberOfInconsistentForceEntries(std::vector& forceValidationResults) const; /** * Get summary string of results * * @param forceValidationResults vector of ForceValidationResult ptrs */ std::string getSummary(std::vector& forceValidationResults) const; /** * Set force tolerance * * @param tolerance force tolerance */ void setForceTolerance(double tolerance); /** * Get force tolerance * * @return force tolerance */ double getForceTolerance() const; /* * Get force tolerance for specified force * * @param forceName name of force * * @return force tolerance * * */ double getForceTolerance(const std::string& forceName) const; /* * Get max errors to print in summary string * * @return max errors to print * * */ int getMaxErrorsToPrint() const; /* * Set max errors to print in summary string * * @param maxErrorsToPrint max errors to print * * */ void setMaxErrorsToPrint(int maxErrorsToPrint); /* * Return true if force is not to be validated (Andersen thermostat, CM motion remover, ...) * * @param forceName force name * * @return true if force is not currently validated **/ int isExcludedForce(std::string forceName) const; private: // initialize class entries void _initialize(); /* * Format output line * * @param tab tab * @param description description * @param value value * * @return string containing contents * * */ std::string _getLine(const std::string& tab, const std::string& description, const std::string& value) const; std::vector _forceValidationResults; // max errors to print int _maxErrorsToPrint; // tolerence double _forceTolerance; // map of force tolerances to type (name) StringDoubleMap _forceTolerances; // forces to be excluded from validation StringIntMap _forcesToBeExcluded; }; } // namespace OpenMM #endif /*VALIDATE_OPENMM_FORCES_H_*/