"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/nni.git" did not exist on "183763effecb80b47ccfe6963424e7fd269b94b7"
ReferenceCustomCVForce.cpp 7.26 KB
Newer Older
1

2
/* Portions copyright (c) 2009-2023 Stanford University and Simbios.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
 * Contributors: Peter Eastman
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject
 * to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE
 * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
 * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include "ReferenceCustomCVForce.h"
#include "ReferencePlatform.h"
#include "ReferenceTabulatedFunction.h"
#include "lepton/CustomFunction.h"
#include "lepton/ParsedExpression.h"
#include "lepton/Parser.h"
31
#include "lepton/Operation.h"
32
33

using namespace OpenMM;
34
using namespace Lepton;
35
36
using namespace std;

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// This class allows us to update tabulated functions without having to recompile expressions
// that use them.  A wrapper does not own a function: it holds a reference to the
// tabulatedFunctions vector plus a slot index, and forwards every call to whatever function
// currently occupies that slot.  Replacing the slot's contents is therefore seen by every
// expression built with this wrapper (or a clone of it).
class ReferenceCustomCVForce::TabulatedFunctionWrapper : public CustomFunction {
public:
    /**
     * @param tabulatedFunctions  vector whose slot this wrapper forwards to; it is held by
     *                            reference, so it must outlive this wrapper and all its clones
     * @param index               slot of tabulatedFunctions to forward to
     */
    TabulatedFunctionWrapper(vector<Lepton::CustomFunction*>& tabulatedFunctions, int index) :
            tabulatedFunctions(tabulatedFunctions), index(index) {
    }
    int getNumArguments() const override {
        return tabulatedFunctions[index]->getNumArguments();
    }
    double evaluate(const double* arguments) const override {
        return tabulatedFunctions[index]->evaluate(arguments);
    }
    double evaluateDerivative(const double* arguments, const int* derivOrder) const override {
        return tabulatedFunctions[index]->evaluateDerivative(arguments, derivOrder);
    }
    CustomFunction* clone() const override {
        // The clone forwards to the same vector and slot, so it also sees later updates.
        return new TabulatedFunctionWrapper(tabulatedFunctions, index);
    }
private:
    vector<Lepton::CustomFunction*>& tabulatedFunctions;
    int index;
};

61
/**
 * Build the reference implementation of a CustomCVForce.
 *
 * The constructor performs these steps:
 * - records the names of global parameters, collective variables and requested
 *   energy-parameter derivatives;
 * - creates the tabulated functions;
 * - compiles the energy expression, plus its derivative with respect to every CV and every
 *   requested parameter;
 * - binds the compiled expressions' variables to globalValues / cvValues.
 *
 * If any step throws (for example the energy expression cannot be parsed), everything
 * allocated here is released before the exception propagates.  This is necessary because the
 * destructor does not run for a partially constructed object.
 */
ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
    int numCVs = force.getNumCollectiveVariables();
    for (int i = 0; i < force.getNumGlobalParameters(); i++)
        globalParameterNames.push_back(force.getGlobalParameterName(i));
    for (int i = 0; i < numCVs; i++)
        variableNames.push_back(force.getCollectiveVariableName(i));
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
        paramDerivNames.push_back(force.getEnergyParameterDerivativeName(i));

    // Wrappers handed to the parser, keyed by the name used in the energy expression.  They
    // are owned here and deleted once the expressions have been compiled.

    map<string, CustomFunction*> functions;
    tabulatedFunctions.resize(force.getNumTabulatedFunctions(), nullptr);
    try {
        // Create custom functions for the tabulated functions.

        for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
            tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
            functions[force.getTabulatedFunctionName(i)] = new TabulatedFunctionWrapper(tabulatedFunctions, i);
        }

        // Create the expressions.

        ParsedExpression energyExpr = Parser::parse(force.getEnergyFunction(), functions).optimize();
        energyExpression = energyExpr.createCompiledExpression();
        variableDerivExpressions.clear();
        for (auto& name : variableNames)
            variableDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
        paramDerivExpressions.clear();
        for (auto& name : paramDerivNames)
            paramDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());

        // Bind every compiled expression's variables to slots in globalValues and cvValues.
        // The expressions are given raw pointers into these vectors, so the vectors must not
        // be resized after this point.

        globalValues.resize(globalParameterNames.size());
        cvValues.resize(numCVs);
        map<string, double*> variableLocations;
        for (size_t i = 0; i < globalParameterNames.size(); i++)
            variableLocations[globalParameterNames[i]] = &globalValues[i];
        for (int i = 0; i < numCVs; i++)
            variableLocations[variableNames[i]] = &cvValues[i];
        energyExpression.setVariableLocations(variableLocations);
        for (CompiledExpression& expr : variableDerivExpressions)
            expr.setVariableLocations(variableLocations);
        for (CompiledExpression& expr : paramDerivExpressions)
            expr.setVariableLocations(variableLocations);
    }
    catch (...) {
        // Release the temporary wrappers and the tabulated functions, then rethrow.
        for (auto& function : functions)
            delete function.second;
        for (CustomFunction* function : tabulatedFunctions)
            delete function;
        tabulatedFunctions.clear();
        throw;
    }

    // Delete the temporary wrappers now that the expressions are compiled.

    for (auto& function : functions)
        delete function.second;
}

108
109
110
/**
 * Replace the stored tabulated functions with freshly built copies of the ones currently held
 * by force.
 *
 * The compiled expressions reach these functions through TabulatedFunctionWrapper by slot
 * index.  Swapping a slot's contents therefore updates the expressions without recompiling.
 */
void ReferenceCustomCVForce::updateTabulatedFunctions(const OpenMM::CustomCVForce& force) {
    // Only the slots created by the constructor are referenced by the compiled expressions.
    // Never write past the end of the vector, even if force now reports more functions.
    int numFunctions = force.getNumTabulatedFunctions();
    if (numFunctions > static_cast<int>(tabulatedFunctions.size()))
        numFunctions = static_cast<int>(tabulatedFunctions.size());
    for (int i = 0; i < numFunctions; i++) {
        // Build the replacement before deleting the old function.  If construction throws,
        // the slot then still holds a valid function instead of a null pointer.
        auto replacement = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
        delete tabulatedFunctions[i];
        tabulatedFunctions[i] = replacement;
    }
}

120
/**
 * Release every tabulated function owned by this object.  Empty (null) slots are harmless,
 * because deleting a null pointer does nothing.
 */
ReferenceCustomCVForce::~ReferenceCustomCVForce() {
    for (auto function : tabulatedFunctions)
        delete function;
}

/**
 * Evaluate the force.  Each collective variable is computed in the inner context; the energy
 * and its derivatives are then combined by the chain rule.
 *
 * @param innerContext       context holding the forces that define the collective variables
 * @param atomCoordinates    particle positions (only its size is used here)
 * @param globalParameters   current values of all global parameters, looked up by name
 * @param forces             forces are accumulated (+=) into this vector
 * @param totalEnergy        if non-null, the energy is added to *totalEnergy
 * @param energyParamDerivs  requested energy-parameter derivatives are accumulated into this map
 */
void ReferenceCustomCVForce::calculateIxn(ContextImpl& innerContext, vector<Vec3>& atomCoordinates,
                                          const map<string, double>& globalParameters, vector<Vec3>& forces,
                                          double* totalEnergy, map<string, double>& energyParamDerivs) {
    // Compute the collective variables, and their derivatives with respect to particle positions.
    // CV i is evaluated by passing the group mask 1<<i to the inner context.
    // innerForces and innerDerivs refer to the inner context's own buffers, which the next
    // evaluation overwrites, so each CV's results are copied out immediately.

    int numCVs = variableNames.size();
    auto* data = static_cast<ReferencePlatform::PlatformData*>(innerContext.getPlatformData());
    vector<Vec3>& innerForces = *static_cast<vector<Vec3>*>(data->forces);
    map<string, double>& innerDerivs = *static_cast<map<string, double>*>(data->energyParameterDerivatives);
    vector<vector<Vec3> > cvForces;
    vector<map<string, double> > cvDerivs;
    cvForces.reserve(numCVs);
    cvDerivs.reserve(numCVs);
    for (int i = 0; i < numCVs; i++) {
        cvValues[i] = innerContext.calcForcesAndEnergy(true, true, 1<<i);
        cvForces.push_back(innerForces);
        cvDerivs.push_back(innerDerivs);
    }

    // Compute the energy and forces.  The compiled expressions read globalValues and cvValues
    // through pointers bound in the constructor, so both must be filled in before evaluating.

    for (size_t i = 0; i < globalParameterNames.size(); i++)
        globalValues[i] = globalParameters.at(globalParameterNames[i]);
    int numParticles = atomCoordinates.size();
    if (totalEnergy != nullptr)
        *totalEnergy += energyExpression.evaluate();

    // dE/dV for each CV.  It is evaluated once here and reused below for the chain rule.
    vector<double> dEdV(numCVs);
    for (int i = 0; i < numCVs; i++) {
        dEdV[i] = variableDerivExpressions[i].evaluate();
        for (int j = 0; j < numParticles; j++)
            forces[j] += cvForces[i][j]*dEdV[i];
    }

    // Compute the energy parameter derivatives.  Each one is the explicit dependence of the
    // energy on the parameter, plus the contribution dE/dV * dV/dparam through every CV.

    if (!paramDerivExpressions.empty()) {
        for (size_t i = 0; i < paramDerivExpressions.size(); i++)
            energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate();
        for (int i = 0; i < numCVs; i++)
            for (const auto& deriv : cvDerivs[i])
                energyParamDerivs[deriv.first] += dEdV[i]*deriv.second;
    }
}