Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
897fb788
"plugins/drude/vscode:/vscode.git/clone" did not exist on "14f8b06118270706a88e2f5498d603f19d8c5e1d"
Commit
897fb788
authored
May 24, 2017
by
Peter Eastman
Browse files
Bug fix to custom integrator
parent
b317bd38
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
15 deletions
+73
-15
openmmapi/include/openmm/internal/ContextImpl.h
openmmapi/include/openmm/internal/ContextImpl.h
+4
-1
openmmapi/src/ContextImpl.cpp
openmmapi/src/ContextImpl.cpp
+1
-1
openmmapi/src/CustomIntegratorUtilities.cpp
openmmapi/src/CustomIntegratorUtilities.cpp
+2
-1
platforms/cuda/src/CudaKernels.cpp
platforms/cuda/src/CudaKernels.cpp
+9
-6
platforms/opencl/src/OpenCLKernels.cpp
platforms/opencl/src/OpenCLKernels.cpp
+9
-6
tests/TestCustomIntegrator.h
tests/TestCustomIntegrator.h
+48
-0
No files found.
openmmapi/include/openmm/internal/ContextImpl.h
View file @
897fb788
...
@@ -197,8 +197,11 @@ public:
...
@@ -197,8 +197,11 @@ public:
double
calcForcesAndEnergy
(
bool
includeForces
,
bool
includeEnergy
,
int
groups
=
0xFFFFFFFF
);
double
calcForcesAndEnergy
(
bool
includeForces
,
bool
includeEnergy
,
int
groups
=
0xFFFFFFFF
);
/**
/**
* Get the set of force group flags that were passed to the most recent call to calcForcesAndEnergy().
* Get the set of force group flags that were passed to the most recent call to calcForcesAndEnergy().
*
* Note that this returns a reference, so it's possible to modify it. Be very very cautious about
* doing that! Only do it if you're also modifying forces stored inside the context.
*/
*/
int
getLastForceGroups
()
const
;
int
&
getLastForceGroups
();
/**
/**
* Calculate the kinetic energy of the system (in kJ/mol).
* Calculate the kinetic energy of the system (in kJ/mol).
*/
*/
...
...
openmmapi/src/ContextImpl.cpp
View file @
897fb788
...
@@ -301,7 +301,7 @@ double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy,
...
@@ -301,7 +301,7 @@ double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy,
}
}
}
}
int
ContextImpl
::
getLastForceGroups
()
const
{
int
&
ContextImpl
::
getLastForceGroups
()
{
return
lastForceGroups
;
return
lastForceGroups
;
}
}
...
...
openmmapi/src/CustomIntegratorUtilities.cpp
View file @
897fb788
...
@@ -111,7 +111,8 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
...
@@ -111,7 +111,8 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
for
(
auto
&
param
:
force
->
getDefaultParameters
())
for
(
auto
&
param
:
force
->
getDefaultParameters
())
affectsForce
.
insert
(
param
.
first
);
affectsForce
.
insert
(
param
.
first
);
for
(
int
i
=
0
;
i
<
numSteps
;
i
++
)
for
(
int
i
=
0
;
i
<
numSteps
;
i
++
)
invalidatesForces
[
i
]
=
(
stepType
[
i
]
==
CustomIntegrator
::
ConstrainPositions
||
affectsForce
.
find
(
stepVariable
[
i
])
!=
affectsForce
.
end
());
invalidatesForces
[
i
]
=
(
stepType
[
i
]
==
CustomIntegrator
::
ConstrainPositions
||
stepType
[
i
]
==
CustomIntegrator
::
UpdateContextState
||
affectsForce
.
find
(
stepVariable
[
i
])
!=
affectsForce
.
end
());
// Make a list of which steps require valid forces or energy to be known.
// Make a list of which steps require valid forces or energy to be known.
...
...
platforms/cuda/src/CudaKernels.cpp
View file @
897fb788
...
@@ -7604,13 +7604,15 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
...
@@ -7604,13 +7604,15 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
int nextStep = step+1;
int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups();
int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
// forces in case we need them again.
cu.getForce().copyTo(*savedForces[lastForceGroups]);
cu.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
validSavedForces.insert(lastForceGroups);
}
}
}
else
else
validSavedForces.clear();
validSavedForces.clear();
...
@@ -7623,6 +7625,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
...
@@ -7623,6 +7625,7 @@ void CudaIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrat
// We can just restore the forces we saved earlier.
// We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cu.getForce());
savedForces[forceGroupFlags[step]]->copyTo(cu.getForce());
context.getLastForceGroups() = forceGroupFlags[step];
}
}
else {
else {
recordChangedParameters(context);
recordChangedParameters(context);
...
...
platforms/opencl/src/OpenCLKernels.cpp
View file @
897fb788
...
@@ -7944,13 +7944,15 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
...
@@ -7944,13 +7944,15 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
int nextStep = step+1;
int nextStep = step+1;
int lastForceGroups = context.getLastForceGroups();
int lastForceGroups = context.getLastForceGroups();
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if ((needsForces[step] || needsEnergy[step]) && (!forcesAreValid || lastForceGroups != forceGroupFlags[step])) {
if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
if (forcesAreValid) {
if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
// The forces are still valid. We just need a different force group right now. Save the old
// The forces are still valid. We just need a different force group right now. Save the old
// forces in case we need them again.
// forces in case we need them again.
cl.getForce().copyTo(*savedForces[lastForceGroups]);
cl.getForce().copyTo(*savedForces[lastForceGroups]);
validSavedForces.insert(lastForceGroups);
validSavedForces.insert(lastForceGroups);
}
}
}
else
else
validSavedForces.clear();
validSavedForces.clear();
...
@@ -7963,6 +7965,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
...
@@ -7963,6 +7965,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
// We can just restore the forces we saved earlier.
// We can just restore the forces we saved earlier.
savedForces[forceGroupFlags[step]]->copyTo(cl.getForce());
savedForces[forceGroupFlags[step]]->copyTo(cl.getForce());
context.getLastForceGroups() = forceGroupFlags[step];
}
}
else {
else {
recordChangedParameters(context);
recordChangedParameters(context);
...
...
tests/TestCustomIntegrator.h
View file @
897fb788
...
@@ -37,6 +37,7 @@
...
@@ -37,6 +37,7 @@
#include "openmm/AndersenThermostat.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h"
#include "openmm/CustomAngleForce.h"
#include "openmm/CustomBondForce.h"
#include "openmm/CustomBondForce.h"
#include "openmm/CustomExternalForce.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h"
#include "openmm/NonbondedForce.h"
...
@@ -921,6 +922,52 @@ void testTabulatedFunction() {
...
@@ -921,6 +922,52 @@ void testTabulatedFunction() {
ASSERT_EQUAL_VEC
(
Vec3
(
12.0
,
13.0
,
14.0
),
values
[
0
],
1e-5
);
ASSERT_EQUAL_VEC
(
Vec3
(
12.0
,
13.0
,
14.0
),
values
[
0
],
1e-5
);
}
}
/**
* Test an integrator that alternates repeatedly between force groups.
*/
void
testAlternatingGroups
()
{
System
system
;
system
.
addParticle
(
1.0
);
CustomExternalForce
*
force1
=
new
CustomExternalForce
(
"-0.5*x"
);
force1
->
addParticle
(
0
);
system
.
addForce
(
force1
);
CustomExternalForce
*
force2
=
new
CustomExternalForce
(
"-0.8*y"
);
force2
->
addParticle
(
0
);
force2
->
setForceGroup
(
1
);
system
.
addForce
(
force2
);
CustomIntegrator
integrator
(
0.5
);
integrator
.
addGlobalVariable
(
"savede1"
,
0.0
);
integrator
.
addGlobalVariable
(
"savede2"
,
0.0
);
integrator
.
addGlobalVariable
(
"savede3"
,
0.0
);
integrator
.
addGlobalVariable
(
"savede4"
,
0.0
);
integrator
.
addPerDofVariable
(
"savedf1"
,
0.0
);
integrator
.
addPerDofVariable
(
"savedf2"
,
0.0
);
integrator
.
addPerDofVariable
(
"savedf3"
,
0.0
);
integrator
.
addPerDofVariable
(
"savedf4"
,
0.0
);
integrator
.
addComputeGlobal
(
"savede1"
,
"energy0"
);
integrator
.
addComputeGlobal
(
"savede2"
,
"energy1"
);
integrator
.
addComputePerDof
(
"savedf1"
,
"f0"
);
integrator
.
addComputePerDof
(
"savedf2"
,
"f1"
);
integrator
.
addComputeGlobal
(
"savede3"
,
"energy0"
);
integrator
.
addComputeGlobal
(
"savede4"
,
"energy1"
);
integrator
.
addComputePerDof
(
"savedf3"
,
"f0"
);
integrator
.
addComputePerDof
(
"savedf4"
,
"f1"
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
1
);
positions
[
0
]
=
Vec3
(
1
,
2
,
3
);
context
.
setPositions
(
positions
);
integrator
.
step
(
1
);
vector
<
Vec3
>
f
;
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
ASSERT_EQUAL_TOL
(
-
0.5
*
1
,
integrator
.
getGlobalVariable
(
2
*
i
),
1e-5
);
ASSERT_EQUAL_TOL
(
-
0.8
*
2
,
integrator
.
getGlobalVariable
(
2
*
i
+
1
),
1e-5
);
integrator
.
getPerDofVariable
(
2
*
i
,
f
);
ASSERT_EQUAL_VEC
(
Vec3
(
0.5
,
0
,
0
),
f
[
0
],
1e-5
);
integrator
.
getPerDofVariable
(
2
*
i
+
1
,
f
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0.8
,
0
),
f
[
0
],
1e-5
);
}
}
void
runPlatformTests
();
void
runPlatformTests
();
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
...
@@ -944,6 +991,7 @@ int main(int argc, char* argv[]) {
...
@@ -944,6 +991,7 @@ int main(int argc, char* argv[]) {
testEnergyParameterDerivatives
();
testEnergyParameterDerivatives
();
testChangeDT
();
testChangeDT
();
testTabulatedFunction
();
testTabulatedFunction
();
testAlternatingGroups
();
runPlatformTests
();
runPlatformTests
();
}
}
catch
(
const
exception
&
e
)
{
catch
(
const
exception
&
e
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment