"docs/vscode:/vscode.git/clone" did not exist on "ee3a11045686f6c51e4663f09b06f4d3485f7a81"
ValidateOpenMMForces.h 11.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#ifndef VALIDATE_OPENMM_FORCES_H_
#define VALIDATE_OPENMM_FORCES_H_

/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2008 Stanford University and the Authors.           *
 * Authors: Mark Friedrichs                                                   *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "ValidateOpenMM.h"

namespace OpenMM {

// Shorthand names for an int -> int map and its mutable / read-only iterators
typedef std::map<int, int>                 MapIntInt;
typedef std::map<int, int>::iterator       MapIntIntI;
typedef std::map<int, int>::const_iterator MapIntIntCI;

// Helper class for the ValidateOpenMMForces class, used to store force results and to facilitate comparisons
// of the resulting forces

class ForceValidationResult {
public:

44
    ForceValidationResult(const Context& context1, const Context& context2, StringUIntMap& forceNamesMap);
45
46
47
48
49
50
51
52
53
    ~ForceValidationResult();

    /**
     * Get potential energy at specified platform index (0 || 1)
     * 
     * @return  potential energy for spercifed platform
     *
     * @throws OpenMMException if energyIndex is not 0 or 1
     */
54
    double getPotentialEnergy(int energyIndex) const;
55
56
57
58
59
60
61
62

    /**
     * Get array of forces at specified platform index (0 || 1)
     * 
     * @return array of force norms
     *
     * @throws OpenMMException if forceIndex is not 0 or 1
     */
63
    std::vector<double> getForceNorms(int forceIndex) const;
64
65
66
67
68
69
70
71

    /**
     * Get array of forces at platform index (0 || 1)
     * 
     * @return force array
     *
     * @throws OpenMMException if forceIndex is not 0 or 1
     */
72
    std::vector<Vec3> getForces(int forceIndex) const;
73
74
75
76
77
78
79
80

    /**
     * Get maximum delta in force norm
     * 
     * @param maxIndex  return atom index of entry with maximum delta norm (optional)
     *
     * @return max delta in norm of forces
     */
81
    double getMaxDeltaForceNorm(int* maxIndex = NULL) const;
82
83
84
85
86
87
88
89

    /**
     * Get maximum relative delta in force norm
     * 
     * @param maxIndex  return atom index of entry w/ maximum relative delta norm (optional)
     *
     * @return max relative delta in norm of forces
     */
90
    double getMaxRelativeDeltaForceNorm(int* maxIndex = NULL) const;
91
92
93
94
95
96
97
98

    /**
     * Get maximum dot product between forces
     * 
     * @param maxIndex  return atom index of entry w/ maximum dot product between forces (optional)
     *
     * @return max dot product between forces
     */
99
    double getMaxDotProduct(int* maxIndex = NULL) const;
100
101
102
103
104
105
106

    /**
     * Get name of force associated w/ computed results
     * 
     * @return force name(s); if more than one force active in computation, 
     * then names are concatenated and separated by '::' (e.g., 'NB_FORCE::GBSA_OBC_FORCE')
     */
107
    std::string getForceName() const;
108
109
110
111
112
113
114
115
116
117

    /**
     * Get platform name
     * 
     * @param index index of platform (0 or 1)
     *
     * @return platform name
     *
     * @throws OpenMMException if index is not 0 or 1
     */
118
    std::string getPlatformName(int index) const;
119
120
121
122
123
124
125

    /**
     * Register index of two entries that differ by a specified tolerance
     * 
     * @param index inconsistent index
     *
     */
126
    void registerInconsistentForceIndex(int index, int value = 1);
127
128
129
130
131

    /**
     * Clear list of entries that differ by a specified tolerance
     * 
     */
132
    void clearInconsistentForceIndexList();
133
134
135
136
137

    /**
     * Get list of entries that differ by a specified tolerance
     * 
     */
138
    void getInconsistentForceIndexList(std::vector<int>& inconsistentIndices) const;
139
140
141
142
143

    /**
     * Get number of entries in inconsistent index list
     * 
     */
144
    int getNumberOfInconsistentForceEntries() const;
145
146
147
148
149
150

    /**
     * Return true if nans were detected
     * 
     * @return true if nans were detected
     */
151
    int nansDetected() const;
152
153
154
155
156
157

    /**
     * Determine if force norms are valid
     * 
     * @param tolerance              tolerance
     */
158
    void compareForceNorms(double tolerance);
159
160
161
162
163
164

    /**
     * Determine if forces are valid
     * 
     * @param tolerance   tolerance
     */
165
    void compareForces(double tolerance);
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195

private:

    // computed potential energies and forces fror two platforms

    double _potentialEnergies[2];
    std::vector<Vec3> _forces[2];

    // platform and force names

    std::string _platforms[2];
    std::vector<std::string> _forceNames;

    // force norms and stat entries

    std::vector<double> _norms[2];
    std::vector<double> _normStatVectors[2];

    // map of indicies w/ inconsistent force entries

    std::map<int, int> _inconsistentForceIndicies;

    // if set, then nans detected

    int _nansDetected;

    /**
     * Calculate norms of vectors
     * 
     */
196
    void _calculateNorms();
197
198
199
200
201

    /**
     * Calculate norms of specified vector
     * 
     */
202
    void _calculateNormOfForceVector(int forceIndex);
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218

    // stat indices

    static const int STAT_AVG = 0;
    static const int STAT_STD = 1;
    static const int STAT_MIN = 2;
    static const int STAT_ID1 = 3;
    static const int STAT_MAX = 4;
    static const int STAT_ID2 = 5;
    static const int STAT_CNT = 6;
    static const int STAT_END = 7;

    /**
     * Find vector stats
     * 
     */
219
    void _findStatsForDouble(const std::vector<double>& array, std::vector<double>& statVector) const;
220
221
222
223
224
225
226
};

// Class used to compare forces/potential energies on two platforms

class ValidateOpenMMForces : public ValidateOpenMM {
public:

227
    OPENMM_VALIDATE_EXPORT ValidateOpenMMForces();
228
    OPENMM_VALIDATE_EXPORT ~ValidateOpenMMForces();
229
230
231
232
233
234
235
236
237
238

    /**
     * Validate force/energy by comparing the results between the forces/energies computed on user-provided context platform
     * with Reference platform
     * 
     * @param context          context reference
     * @param summaryString    output summary string of results of comparison (optional)
     *
     * @return number of inconsistent entries
     */
239
     int OPENMM_VALIDATE_EXPORT compareWithReferencePlatform(Context& context, std::string* summaryString = NULL);
240
241
242
243
244
245
246
247
248
249
250
251
252

    /**
     * Validate force/energy by comparing the results between the forces/energies computed on two different platforms
     * 
     * @param context          context reference
     * @param compareForces    indices of force to be tested
     * @param platform1        first platform to compute forces
     * @param platform2        second platform to compute forces
     *
     * @return ForceValidationResult reference containing results of force/energy computations
     *         on the two input platforms
     */
     ForceValidationResult* compareForce(Context& context, std::vector<int>& compareForces,
253
                                         Platform& platform1, Platform& platform2) const;
254
255
256
257
258
259
260
261
262
263

    /**
     * Compare individual forces by comparing calculations across two platforms (platform associated w/ input context and
     * comparisonPlatform)
     * 
     * @param context                 context reference
     * @param platform                comparsion platform reference 
     * @param forceValidationResults  output vector of ForceValidationResult ptrs (user is responsible for deleting
     *                                individual ForceValidationResult objects)
     */
264
    void compareOpenMMForces(Context& context, Platform& comparisonPlatform, std::vector<ForceValidationResult*>& forceValidationResults) const;
265
266
267
268
269
270

    /**
     * Determine if results are consistent
     * 
     * @param forceValidationResults  vector of ForceValidationResult ptrs to check if forces are consistent
     */
271
    void checkForInconsistentForceEntries(std::vector<ForceValidationResult*>& forceValidationResults) const;
272
273
274
275
276
277

    /**
     * Get total number of force entries that are inconsistent
     * 
     * @param forceValidationResults  vector of ForceValidationResult ptrs to check if forces are consistent
     */
278
    int getTotalNumberOfInconsistentForceEntries(std::vector<ForceValidationResult*>& forceValidationResults) const;
279
280
281
282
283
284

    /**
     * Get summary string of results
     * 
     * @param forceValidationResults  vector of ForceValidationResult ptrs
     */
285
    std::string getSummary(std::vector<ForceValidationResult*>& forceValidationResults) const;
286
287
288
289
290
291

    /**
     * Set force tolerance
     * 
     * @param tolerance     force tolerance
     */
292
    void setForceTolerance(double tolerance);
293
294
295
296
297
298

    /**
     * Get force tolerance
     * 
     * @return force tolerance
     */
299
    double getForceTolerance() const;
300
301
302
303
304
305
306
307
308
309

    /* 
     * Get force tolerance for specified force
     *
     * @param forceName   name of force
     *
     * @return force tolerance
     *
     * */
      
310
    double getForceTolerance(const std::string& forceName) const;
311
312
313
314
315
316
317
318
       
    /* 
     * Get max errors to print in summary string
     *
     * @return max errors to print
     *
     * */
      
319
    int getMaxErrorsToPrint() const;
320
321
322
323
324
325
326
327
       
    /* 
     * Set max errors to print in summary string
     *
     * @param maxErrorsToPrint max errors to print
     *
     * */
      
328
    void setMaxErrorsToPrint(int maxErrorsToPrint);
329
330
331
332
333
334
335
336
337
       
    /* 
     * Return true if force is not to be validated (Andersen thermostat, CM motion remover, ...)
     *
     * @param forceName   force name
     *
     * @return true if force is not currently validated
     **/
      
338
    int isExcludedForce(std::string forceName) const;
339
340
341
342
343
       
private:

     // initialize class entries

344
     void _initialize();
345
346
347
348
349
350
351
352
353
354
355
356

      /* 
       * Format output line
       *
       * @param tab         tab
       * @param description description
       * @param value       value
       *
       * @return string containing contents
       *
       * */
      
357
358
359
      std::string _getLine(const std::string& tab,
                           const std::string& description,
                           const std::string& value) const;
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
      
     std::vector<ForceValidationResult*> _forceValidationResults;

     // max errors to print

     int _maxErrorsToPrint;

     // tolerence

     double _forceTolerance;

     // map of force tolerances to type (name)

     StringDoubleMap _forceTolerances;
     
     // forces to be excluded from validation

     StringIntMap _forcesToBeExcluded;
};

} // namespace OpenMM

#endif /*VALIDATE_OPENMM_FORCES_H_*/