Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
17b61225
Unverified
Commit
17b61225
authored
Jan 31, 2023
by
Peter Eastman
Committed by
GitHub
Jan 31, 2023
Browse files
Use CompiledExpression for CustomCVForce energy expression (#3898)
parent
b57a5a63
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
167 additions
and
107 deletions
+167
-107
platforms/common/include/openmm/common/CommonKernels.h
platforms/common/include/openmm/common/CommonKernels.h
+7
-3
platforms/common/src/CommonKernels.cpp
platforms/common/src/CommonKernels.cpp
+75
-50
platforms/reference/include/ReferenceCustomCVForce.h
platforms/reference/include/ReferenceCustomCVForce.h
+11
-7
platforms/reference/src/SimTKReference/ReferenceCustomCVForce.cpp
...s/reference/src/SimTKReference/ReferenceCustomCVForce.cpp
+74
-47
No files found.
platforms/common/include/openmm/common/CommonKernels.h
View file @
17b61225
...
@@ -909,6 +909,7 @@ public:
...
@@ -909,6 +909,7 @@ public:
CommonCalcCustomCVForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
ComputeContext
&
cc
)
:
CalcCustomCVForceKernel
(
name
,
platform
),
CommonCalcCustomCVForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
ComputeContext
&
cc
)
:
CalcCustomCVForceKernel
(
name
,
platform
),
cc
(
cc
),
hasInitializedListeners
(
false
)
{
cc
(
cc
),
hasInitializedListeners
(
false
)
{
}
}
~
CommonCalcCustomCVForceKernel
();
/**
/**
* Initialize the kernel.
* Initialize the kernel.
*
*
...
@@ -948,13 +949,16 @@ public:
...
@@ -948,13 +949,16 @@ public:
private:
private:
class
ForceInfo
;
class
ForceInfo
;
class
ReorderListener
;
class
ReorderListener
;
class
TabulatedFunctionWrapper
;
ComputeContext
&
cc
;
ComputeContext
&
cc
;
bool
hasInitializedListeners
;
bool
hasInitializedListeners
;
Lepton
::
Expression
Program
energyExpression
;
Lepton
::
Compiled
Expression
energyExpression
;
std
::
vector
<
std
::
string
>
variableNames
,
paramDerivNames
,
globalParameterNames
;
std
::
vector
<
std
::
string
>
variableNames
,
paramDerivNames
,
globalParameterNames
;
std
::
vector
<
Lepton
::
Expression
Program
>
variableDerivExpressions
;
std
::
vector
<
Lepton
::
Compiled
Expression
>
variableDerivExpressions
;
std
::
vector
<
Lepton
::
Expression
Program
>
paramDerivExpressions
;
std
::
vector
<
Lepton
::
Compiled
Expression
>
paramDerivExpressions
;
std
::
vector
<
ComputeArray
>
cvForces
;
std
::
vector
<
ComputeArray
>
cvForces
;
std
::
vector
<
double
>
globalValues
,
cvValues
;
std
::
vector
<
Lepton
::
CustomFunction
*>
tabulatedFunctions
;
ComputeArray
invAtomOrder
;
ComputeArray
invAtomOrder
;
ComputeArray
innerInvAtomOrder
;
ComputeArray
innerInvAtomOrder
;
ComputeKernel
copyStateKernel
,
copyForcesKernel
,
addForcesKernel
;
ComputeKernel
copyStateKernel
,
copyForcesKernel
,
addForcesKernel
;
...
...
platforms/common/src/CommonKernels.cpp
View file @
17b61225
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2008-202
2
Stanford University and the Authors. *
* Portions copyright (c) 2008-202
3
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -98,15 +98,6 @@ static pair<ExpressionTreeNode, string> makeVariable(const string& name, const s
...
@@ -98,15 +98,6 @@ static pair<ExpressionTreeNode, string> makeVariable(const string& name, const s
return
make_pair
(
ExpressionTreeNode
(
new
Operation
::
Variable
(
name
)),
value
);
return
make_pair
(
ExpressionTreeNode
(
new
Operation
::
Variable
(
name
)),
value
);
}
}
static
void
replaceFunctionsInExpression
(
map
<
string
,
CustomFunction
*>&
functions
,
ExpressionProgram
&
expression
)
{
for
(
int
i
=
0
;
i
<
expression
.
getNumOperations
();
i
++
)
{
if
(
expression
.
getOperation
(
i
).
getId
()
==
Operation
::
CUSTOM
)
{
const
Operation
::
Custom
&
op
=
dynamic_cast
<
const
Operation
::
Custom
&>
(
expression
.
getOperation
(
i
));
expression
.
setOperation
(
i
,
new
Operation
::
Custom
(
op
.
getName
(),
functions
[
op
.
getName
()]
->
clone
(),
op
.
getDerivOrder
()));
}
}
}
void
CommonApplyConstraintsKernel
::
initialize
(
const
System
&
system
)
{
void
CommonApplyConstraintsKernel
::
initialize
(
const
System
&
system
)
{
}
}
...
@@ -5136,6 +5127,30 @@ private:
...
@@ -5136,6 +5127,30 @@ private:
ArrayInterface
&
invAtomOrder
;
ArrayInterface
&
invAtomOrder
;
};
};
// This class allows us to update tabulated functions without having to recompile expressions
// that use them.
class
CommonCalcCustomCVForceKernel
::
TabulatedFunctionWrapper
:
public
CustomFunction
{
public:
TabulatedFunctionWrapper
(
vector
<
Lepton
::
CustomFunction
*>&
tabulatedFunctions
,
int
index
)
:
tabulatedFunctions
(
tabulatedFunctions
),
index
(
index
)
{
}
int
getNumArguments
()
const
{
return
tabulatedFunctions
[
index
]
->
getNumArguments
();
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
tabulatedFunctions
[
index
]
->
evaluate
(
arguments
);
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
tabulatedFunctions
[
index
]
->
evaluateDerivative
(
arguments
,
derivOrder
);
}
CustomFunction
*
clone
()
const
{
return
new
TabulatedFunctionWrapper
(
tabulatedFunctions
,
index
);
}
private:
vector
<
Lepton
::
CustomFunction
*>&
tabulatedFunctions
;
int
index
;
};
void
CommonCalcCustomCVForceKernel
::
initialize
(
const
System
&
system
,
const
CustomCVForce
&
force
,
ContextImpl
&
innerContext
)
{
void
CommonCalcCustomCVForceKernel
::
initialize
(
const
System
&
system
,
const
CustomCVForce
&
force
,
ContextImpl
&
innerContext
)
{
ContextSelector
selector
(
cc
);
ContextSelector
selector
(
cc
);
int
numCVs
=
force
.
getNumCollectiveVariables
();
int
numCVs
=
force
.
getNumCollectiveVariables
();
...
@@ -5152,19 +5167,34 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
...
@@ -5152,19 +5167,34 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
// Create custom functions for the tabulated functions.
// Create custom functions for the tabulated functions.
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
for
(
int
i
=
0
;
i
<
(
int
)
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunctions
.
resize
(
force
.
getNumTabulatedFunctions
(),
NULL
);
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
tabulatedFunctions
[
i
]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
new
TabulatedFunctionWrapper
(
tabulatedFunctions
,
i
);
}
// Create the expressions.
// Create the expressions.
Lepton
::
ParsedExpression
energyExpr
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
);
Lepton
::
ParsedExpression
energyExpr
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
)
.
optimize
()
;
energyExpression
=
energyExpr
.
create
Program
();
energyExpression
=
energyExpr
.
create
CompiledExpression
();
variableDerivExpressions
.
clear
();
variableDerivExpressions
.
clear
();
for
(
auto
&
name
:
variableNames
)
for
(
auto
&
name
:
variableNames
)
variableDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
optimize
().
createProgram
());
variableDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
createCompiledExpression
());
paramDerivExpressions
.
clear
();
paramDerivExpressions
.
clear
();
for
(
auto
&
name
:
paramDerivNames
)
for
(
auto
&
name
:
paramDerivNames
)
paramDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
optimize
().
createProgram
());
paramDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
createCompiledExpression
());
globalValues
.
resize
(
globalParameterNames
.
size
());
cvValues
.
resize
(
numCVs
);
map
<
string
,
double
*>
variableLocations
;
for
(
int
i
=
0
;
i
<
globalParameterNames
.
size
();
i
++
)
variableLocations
[
globalParameterNames
[
i
]]
=
&
globalValues
[
i
];
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
variableLocations
[
variableNames
[
i
]]
=
&
cvValues
[
i
];
energyExpression
.
setVariableLocations
(
variableLocations
);
for
(
CompiledExpression
&
expr
:
variableDerivExpressions
)
expr
.
setVariableLocations
(
variableLocations
);
for
(
CompiledExpression
&
expr
:
paramDerivExpressions
)
expr
.
setVariableLocations
(
variableLocations
);
// Delete the custom functions.
// Delete the custom functions.
...
@@ -5229,15 +5259,20 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
...
@@ -5229,15 +5259,20 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
cc
.
addForce
(
new
ForceInfo
(
*
info
));
cc
.
addForce
(
new
ForceInfo
(
*
info
));
}
}
CommonCalcCustomCVForceKernel
::~
CommonCalcCustomCVForceKernel
()
{
for
(
int
i
=
0
;
i
<
tabulatedFunctions
.
size
();
i
++
)
if
(
tabulatedFunctions
[
i
]
!=
NULL
)
delete
tabulatedFunctions
[
i
];
}
double
CommonCalcCustomCVForceKernel
::
execute
(
ContextImpl
&
context
,
ContextImpl
&
innerContext
,
bool
includeForces
,
bool
includeEnergy
)
{
double
CommonCalcCustomCVForceKernel
::
execute
(
ContextImpl
&
context
,
ContextImpl
&
innerContext
,
bool
includeForces
,
bool
includeEnergy
)
{
copyState
(
context
,
innerContext
);
copyState
(
context
,
innerContext
);
int
numCVs
=
variableNames
.
size
();
int
numCVs
=
variableNames
.
size
();
int
numAtoms
=
cc
.
getNumAtoms
();
int
numAtoms
=
cc
.
getNumAtoms
();
int
paddedNumAtoms
=
cc
.
getPaddedNumAtoms
();
int
paddedNumAtoms
=
cc
.
getPaddedNumAtoms
();
vector
<
double
>
cvValues
;
vector
<
map
<
string
,
double
>
>
cvDerivs
(
numCVs
);
vector
<
map
<
string
,
double
>
>
cvDerivs
(
numCVs
);
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
cvValues
.
push_back
(
innerContext
.
calcForcesAndEnergy
(
true
,
true
,
1
<<
i
)
)
;
cvValues
[
i
]
=
innerContext
.
calcForcesAndEnergy
(
true
,
true
,
1
<<
i
);
ContextSelector
selector
(
cc
);
ContextSelector
selector
(
cc
);
copyForcesKernel
->
setArg
(
0
,
cvForces
[
i
]);
copyForcesKernel
->
setArg
(
0
,
cvForces
[
i
]);
copyForcesKernel
->
execute
(
numAtoms
);
copyForcesKernel
->
execute
(
numAtoms
);
...
@@ -5247,14 +5282,11 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
...
@@ -5247,14 +5282,11 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
// Compute the energy and forces.
// Compute the energy and forces.
ContextSelector
selector
(
cc
);
ContextSelector
selector
(
cc
);
map
<
string
,
double
>
variables
;
for
(
int
i
=
0
;
i
<
globalParameterNames
.
size
();
i
++
)
for
(
auto
&
name
:
globalParameterNames
)
globalValues
[
i
]
=
context
.
getParameter
(
globalParameterNames
[
i
]);
variables
[
name
]
=
context
.
getParameter
(
name
);
double
energy
=
energyExpression
.
evaluate
();
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
variables
[
variableNames
[
i
]]
=
cvValues
[
i
];
double
energy
=
energyExpression
.
evaluate
(
variables
);
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
(
variables
);
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
();
addForcesKernel
->
setArg
(
2
*
i
+
2
,
cvForces
[
i
]);
addForcesKernel
->
setArg
(
2
*
i
+
2
,
cvForces
[
i
]);
if
(
cc
.
getUseDoublePrecision
())
if
(
cc
.
getUseDoublePrecision
())
addForcesKernel
->
setArg
(
2
*
i
+
3
,
dEdV
);
addForcesKernel
->
setArg
(
2
*
i
+
3
,
dEdV
);
...
@@ -5262,16 +5294,18 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
...
@@ -5262,16 +5294,18 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
addForcesKernel
->
setArg
(
2
*
i
+
3
,
(
float
)
dEdV
);
addForcesKernel
->
setArg
(
2
*
i
+
3
,
(
float
)
dEdV
);
}
}
addForcesKernel
->
execute
(
numAtoms
);
addForcesKernel
->
execute
(
numAtoms
);
// Compute the energy parameter derivatives.
// Compute the energy parameter derivatives.
map
<
string
,
double
>&
energyParamDerivs
=
cc
.
getEnergyParamDerivWorkspace
();
if
(
paramDerivExpressions
.
size
()
>
0
)
{
for
(
int
i
=
0
;
i
<
paramDerivExpressions
.
size
();
i
++
)
map
<
string
,
double
>&
energyParamDerivs
=
cc
.
getEnergyParamDerivWorkspace
();
energyParamDerivs
[
paramDerivNames
[
i
]]
+=
paramDerivExpressions
[
i
].
evaluate
(
variables
);
for
(
int
i
=
0
;
i
<
paramDerivExpressions
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
energyParamDerivs
[
paramDerivNames
[
i
]]
+=
paramDerivExpressions
[
i
].
evaluate
();
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
(
variables
);
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
for
(
auto
&
deriv
:
cvDerivs
[
i
])
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
();
energyParamDerivs
[
deriv
.
first
]
+=
dEdV
*
deriv
.
second
;
for
(
auto
&
deriv
:
cvDerivs
[
i
])
energyParamDerivs
[
deriv
.
first
]
+=
dEdV
*
deriv
.
second
;
}
}
}
return
energy
;
return
energy
;
}
}
...
@@ -5305,22 +5339,13 @@ void CommonCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
...
@@ -5305,22 +5339,13 @@ void CommonCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
void
CommonCalcCustomCVForceKernel
::
copyParametersToContext
(
ContextImpl
&
context
,
const
CustomCVForce
&
force
)
{
void
CommonCalcCustomCVForceKernel
::
copyParametersToContext
(
ContextImpl
&
context
,
const
CustomCVForce
&
force
)
{
// Create custom functions for the tabulated functions.
// Create custom functions for the tabulated functions.
map
<
string
,
CustomFunction
*>
functions
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
(
int
)
force
.
getNumTabulatedFunctions
();
i
++
)
if
(
tabulatedFunctions
[
i
]
!=
NULL
)
{
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
delete
tabulatedFunctions
[
i
];
tabulatedFunctions
[
i
]
=
NULL
;
// Replace tabulated functions in the expressions.
}
tabulatedFunctions
[
i
]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
replaceFunctionsInExpression
(
functions
,
energyExpression
);
}
for
(
auto
&
expression
:
variableDerivExpressions
)
replaceFunctionsInExpression
(
functions
,
expression
);
for
(
auto
&
expression
:
paramDerivExpressions
)
replaceFunctionsInExpression
(
functions
,
expression
);
// Delete the custom functions.
for
(
auto
&
function
:
functions
)
delete
function
.
second
;
}
}
void
CommonIntegrateVerletStepKernel
::
initialize
(
const
System
&
system
,
const
VerletIntegrator
&
integrator
)
{
void
CommonIntegrateVerletStepKernel
::
initialize
(
const
System
&
system
,
const
VerletIntegrator
&
integrator
)
{
...
...
platforms/reference/include/ReferenceCustomCVForce.h
View file @
17b61225
/* Portions copyright (c) 2017 Stanford University and Simbios.
/* Portions copyright (c) 2017
-2023
Stanford University and Simbios.
* Contributors: Peter Eastman
* Contributors: Peter Eastman
*
*
* Permission is hereby granted, free of charge, to any person obtaining
* Permission is hereby granted, free of charge, to any person obtaining
...
@@ -27,7 +27,8 @@
...
@@ -27,7 +27,8 @@
#include "openmm/CustomCVForce.h"
#include "openmm/CustomCVForce.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/ContextImpl.h"
#include "lepton/ExpressionProgram.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include <map>
#include <map>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -36,10 +37,13 @@ namespace OpenMM {
...
@@ -36,10 +37,13 @@ namespace OpenMM {
class
ReferenceCustomCVForce
{
class
ReferenceCustomCVForce
{
private:
private:
Lepton
::
ExpressionProgram
energyExpression
;
class
TabulatedFunctionWrapper
;
std
::
vector
<
std
::
string
>
variableNames
,
paramDerivNames
;
Lepton
::
CompiledExpression
energyExpression
;
std
::
vector
<
Lepton
::
ExpressionProgram
>
variableDerivExpressions
;
std
::
vector
<
std
::
string
>
variableNames
,
paramDerivNames
,
globalParameterNames
;
std
::
vector
<
Lepton
::
ExpressionProgram
>
paramDerivExpressions
;
std
::
vector
<
Lepton
::
CompiledExpression
>
variableDerivExpressions
;
std
::
vector
<
Lepton
::
CompiledExpression
>
paramDerivExpressions
;
std
::
vector
<
double
>
globalValues
,
cvValues
;
std
::
vector
<
Lepton
::
CustomFunction
*>
tabulatedFunctions
;
public:
public:
/**
/**
...
@@ -70,7 +74,7 @@ public:
...
@@ -70,7 +74,7 @@ public:
*/
*/
void
calculateIxn
(
ContextImpl
&
innerContext
,
std
::
vector
<
OpenMM
::
Vec3
>&
atomCoordinates
,
void
calculateIxn
(
ContextImpl
&
innerContext
,
std
::
vector
<
OpenMM
::
Vec3
>&
atomCoordinates
,
const
std
::
map
<
std
::
string
,
double
>&
globalParameters
,
const
std
::
map
<
std
::
string
,
double
>&
globalParameters
,
std
::
vector
<
OpenMM
::
Vec3
>&
forces
,
double
*
totalEnergy
,
std
::
map
<
std
::
string
,
double
>&
energyParamDerivs
)
const
;
std
::
vector
<
OpenMM
::
Vec3
>&
forces
,
double
*
totalEnergy
,
std
::
map
<
std
::
string
,
double
>&
energyParamDerivs
);
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
platforms/reference/src/SimTKReference/ReferenceCustomCVForce.cpp
View file @
17b61225
/* Portions copyright (c) 2009-20
17
Stanford University and Simbios.
/* Portions copyright (c) 2009-20
23
Stanford University and Simbios.
* Contributors: Peter Eastman
* Contributors: Peter Eastman
*
*
* Permission is hereby granted, free of charge, to any person obtaining
* Permission is hereby granted, free of charge, to any person obtaining
...
@@ -34,8 +34,35 @@ using namespace OpenMM;
...
@@ -34,8 +34,35 @@ using namespace OpenMM;
using
namespace
Lepton
;
using
namespace
Lepton
;
using
namespace
std
;
using
namespace
std
;
// This class allows us to update tabulated functions without having to recompile expressions
// that use them.
class
ReferenceCustomCVForce
::
TabulatedFunctionWrapper
:
public
CustomFunction
{
public:
TabulatedFunctionWrapper
(
vector
<
Lepton
::
CustomFunction
*>&
tabulatedFunctions
,
int
index
)
:
tabulatedFunctions
(
tabulatedFunctions
),
index
(
index
)
{
}
int
getNumArguments
()
const
{
return
tabulatedFunctions
[
index
]
->
getNumArguments
();
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
tabulatedFunctions
[
index
]
->
evaluate
(
arguments
);
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
tabulatedFunctions
[
index
]
->
evaluateDerivative
(
arguments
,
derivOrder
);
}
CustomFunction
*
clone
()
const
{
return
new
TabulatedFunctionWrapper
(
tabulatedFunctions
,
index
);
}
private:
vector
<
Lepton
::
CustomFunction
*>&
tabulatedFunctions
;
int
index
;
};
ReferenceCustomCVForce
::
ReferenceCustomCVForce
(
const
CustomCVForce
&
force
)
{
ReferenceCustomCVForce
::
ReferenceCustomCVForce
(
const
CustomCVForce
&
force
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumCollectiveVariables
();
i
++
)
int
numCVs
=
force
.
getNumCollectiveVariables
();
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
globalParameterNames
.
push_back
(
force
.
getGlobalParameterName
(
i
));
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
variableNames
.
push_back
(
force
.
getCollectiveVariableName
(
i
));
variableNames
.
push_back
(
force
.
getCollectiveVariableName
(
i
));
for
(
int
i
=
0
;
i
<
force
.
getNumEnergyParameterDerivatives
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumEnergyParameterDerivatives
();
i
++
)
paramDerivNames
.
push_back
(
force
.
getEnergyParameterDerivativeName
(
i
));
paramDerivNames
.
push_back
(
force
.
getEnergyParameterDerivativeName
(
i
));
...
@@ -43,19 +70,34 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
...
@@ -43,19 +70,34 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
// Create custom functions for the tabulated functions.
// Create custom functions for the tabulated functions.
map
<
string
,
CustomFunction
*>
functions
;
map
<
string
,
CustomFunction
*>
functions
;
for
(
int
i
=
0
;
i
<
(
int
)
force
.
getNumTabulatedFunctions
();
i
++
)
tabulatedFunctions
.
resize
(
force
.
getNumTabulatedFunctions
(),
NULL
);
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
tabulatedFunctions
[
i
]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
new
TabulatedFunctionWrapper
(
tabulatedFunctions
,
i
);
}
// Create the expressions.
// Create the expressions.
ParsedExpression
energyExpr
=
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
);
ParsedExpression
energyExpr
=
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
)
.
optimize
()
;
energyExpression
=
energyExpr
.
create
Program
();
energyExpression
=
energyExpr
.
create
CompiledExpression
();
variableDerivExpressions
.
clear
();
variableDerivExpressions
.
clear
();
for
(
auto
&
name
:
variableNames
)
for
(
auto
&
name
:
variableNames
)
variableDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
optimize
().
createProgram
());
variableDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
createCompiledExpression
());
paramDerivExpressions
.
clear
();
paramDerivExpressions
.
clear
();
for
(
auto
&
name
:
paramDerivNames
)
for
(
auto
&
name
:
paramDerivNames
)
paramDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
optimize
().
createProgram
());
paramDerivExpressions
.
push_back
(
energyExpr
.
differentiate
(
name
).
createCompiledExpression
());
globalValues
.
resize
(
variableNames
.
size
());
cvValues
.
resize
(
numCVs
);
map
<
string
,
double
*>
variableLocations
;
for
(
int
i
=
0
;
i
<
globalParameterNames
.
size
();
i
++
)
variableLocations
[
globalParameterNames
[
i
]]
=
&
globalValues
[
i
];
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
variableLocations
[
variableNames
[
i
]]
=
&
cvValues
[
i
];
energyExpression
.
setVariableLocations
(
variableLocations
);
for
(
CompiledExpression
&
expr
:
variableDerivExpressions
)
expr
.
setVariableLocations
(
variableLocations
);
for
(
CompiledExpression
&
expr
:
paramDerivExpressions
)
expr
.
setVariableLocations
(
variableLocations
);
// Delete the custom functions.
// Delete the custom functions.
...
@@ -63,78 +105,63 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
...
@@ -63,78 +105,63 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
delete
function
.
second
;
delete
function
.
second
;
}
}
static
void
replaceFunctionsInExpression
(
map
<
string
,
CustomFunction
*>&
functions
,
ExpressionProgram
&
expression
)
{
for
(
int
i
=
0
;
i
<
expression
.
getNumOperations
();
i
++
)
{
if
(
expression
.
getOperation
(
i
).
getId
()
==
Operation
::
CUSTOM
)
{
const
Operation
::
Custom
&
op
=
dynamic_cast
<
const
Operation
::
Custom
&>
(
expression
.
getOperation
(
i
));
expression
.
setOperation
(
i
,
new
Operation
::
Custom
(
op
.
getName
(),
functions
[
op
.
getName
()]
->
clone
(),
op
.
getDerivOrder
()));
}
}
}
void
ReferenceCustomCVForce
::
updateTabulatedFunctions
(
const
OpenMM
::
CustomCVForce
&
force
)
{
void
ReferenceCustomCVForce
::
updateTabulatedFunctions
(
const
OpenMM
::
CustomCVForce
&
force
)
{
// Create custom functions for the tabulated functions.
// Create custom functions for the tabulated functions.
map
<
string
,
CustomFunction
*>
functions
;
for
(
int
i
=
0
;
i
<
force
.
getNumTabulatedFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
(
int
)
force
.
getNumTabulatedFunctions
();
i
++
)
if
(
tabulatedFunctions
[
i
]
!=
NULL
)
{
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
delete
tabulatedFunctions
[
i
];
tabulatedFunctions
[
i
]
=
NULL
;
// Replace tabulated functions in the expressions.
}
tabulatedFunctions
[
i
]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
replaceFunctionsInExpression
(
functions
,
energyExpression
);
}
for
(
auto
&
expression
:
variableDerivExpressions
)
replaceFunctionsInExpression
(
functions
,
expression
);
for
(
auto
&
expression
:
paramDerivExpressions
)
replaceFunctionsInExpression
(
functions
,
expression
);
// Delete the custom functions.
for
(
auto
&
function
:
functions
)
delete
function
.
second
;
}
}
ReferenceCustomCVForce
::~
ReferenceCustomCVForce
()
{
ReferenceCustomCVForce
::~
ReferenceCustomCVForce
()
{
for
(
int
i
=
0
;
i
<
tabulatedFunctions
.
size
();
i
++
)
if
(
tabulatedFunctions
[
i
]
!=
NULL
)
delete
tabulatedFunctions
[
i
];
}
}
void
ReferenceCustomCVForce
::
calculateIxn
(
ContextImpl
&
innerContext
,
vector
<
Vec3
>&
atomCoordinates
,
void
ReferenceCustomCVForce
::
calculateIxn
(
ContextImpl
&
innerContext
,
vector
<
Vec3
>&
atomCoordinates
,
const
map
<
string
,
double
>&
globalParameters
,
vector
<
Vec3
>&
forces
,
const
map
<
string
,
double
>&
globalParameters
,
vector
<
Vec3
>&
forces
,
double
*
totalEnergy
,
map
<
string
,
double
>&
energyParamDerivs
)
const
{
double
*
totalEnergy
,
map
<
string
,
double
>&
energyParamDerivs
)
{
// Compute the collective variables, and their derivatives with respect to particle positions.
// Compute the collective variables, and their derivatives with respect to particle positions.
int
numCVs
=
variableNames
.
size
();
int
numCVs
=
variableNames
.
size
();
ReferencePlatform
::
PlatformData
*
data
=
reinterpret_cast
<
ReferencePlatform
::
PlatformData
*>
(
innerContext
.
getPlatformData
());
ReferencePlatform
::
PlatformData
*
data
=
reinterpret_cast
<
ReferencePlatform
::
PlatformData
*>
(
innerContext
.
getPlatformData
());
vector
<
Vec3
>&
innerForces
=
*
((
vector
<
Vec3
>*
)
data
->
forces
);
vector
<
Vec3
>&
innerForces
=
*
((
vector
<
Vec3
>*
)
data
->
forces
);
map
<
string
,
double
>&
innerDerivs
=
*
((
map
<
string
,
double
>*
)
data
->
energyParameterDerivatives
);
map
<
string
,
double
>&
innerDerivs
=
*
((
map
<
string
,
double
>*
)
data
->
energyParameterDerivatives
);
vector
<
double
>
cvValues
;
vector
<
vector
<
Vec3
>
>
cvForces
;
vector
<
vector
<
Vec3
>
>
cvForces
;
vector
<
map
<
string
,
double
>
>
cvDerivs
;
vector
<
map
<
string
,
double
>
>
cvDerivs
;
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
cvValues
.
push_back
(
innerContext
.
calcForcesAndEnergy
(
true
,
true
,
1
<<
i
)
)
;
cvValues
[
i
]
=
innerContext
.
calcForcesAndEnergy
(
true
,
true
,
1
<<
i
);
cvForces
.
push_back
(
innerForces
);
cvForces
.
push_back
(
innerForces
);
cvDerivs
.
push_back
(
innerDerivs
);
cvDerivs
.
push_back
(
innerDerivs
);
}
}
// Compute the energy and forces.
// Compute the energy and forces.
for
(
int
i
=
0
;
i
<
globalParameterNames
.
size
();
i
++
)
globalValues
[
i
]
=
globalParameters
.
at
(
globalParameterNames
[
i
]);
int
numParticles
=
atomCoordinates
.
size
();
int
numParticles
=
atomCoordinates
.
size
();
map
<
string
,
double
>
variables
=
globalParameters
;
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
variables
[
variableNames
[
i
]]
=
cvValues
[
i
];
if
(
totalEnergy
!=
NULL
)
if
(
totalEnergy
!=
NULL
)
*
totalEnergy
+=
energyExpression
.
evaluate
(
variables
);
*
totalEnergy
+=
energyExpression
.
evaluate
();
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
(
variables
);
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
();
for
(
int
j
=
0
;
j
<
numParticles
;
j
++
)
for
(
int
j
=
0
;
j
<
numParticles
;
j
++
)
forces
[
j
]
+=
cvForces
[
i
][
j
]
*
dEdV
;
forces
[
j
]
+=
cvForces
[
i
][
j
]
*
dEdV
;
}
}
// Compute the energy parameter derivatives.
// Compute the energy parameter derivatives.
for
(
int
i
=
0
;
i
<
paramDerivExpressions
.
size
();
i
++
)
if
(
paramDerivExpressions
.
size
()
>
0
)
{
energyParamDerivs
[
paramDerivNames
[
i
]]
+=
paramDerivExpressions
[
i
].
evaluate
(
variables
);
for
(
int
i
=
0
;
i
<
paramDerivExpressions
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
energyParamDerivs
[
paramDerivNames
[
i
]]
+=
paramDerivExpressions
[
i
].
evaluate
();
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
(
variables
);
for
(
int
i
=
0
;
i
<
numCVs
;
i
++
)
{
for
(
auto
&
deriv
:
cvDerivs
[
i
])
double
dEdV
=
variableDerivExpressions
[
i
].
evaluate
();
energyParamDerivs
[
deriv
.
first
]
+=
dEdV
*
deriv
.
second
;
for
(
auto
&
deriv
:
cvDerivs
[
i
])
energyParamDerivs
[
deriv
.
first
]
+=
dEdV
*
deriv
.
second
;
}
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment