Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
c3c8ec55
"...reference/src/SimTKReference/ReferenceCCMAAlgorithm.cpp" did not exist on "671419fac9c9f62999d1e71cc98b0697c67b68ac"
Unverified
Commit
c3c8ec55
authored
May 19, 2022
by
Peter Eastman
Committed by
GitHub
May 19, 2022
Browse files
Vectorized calculating long range correction coefficient (#3606)
parent
b096cd7c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
84 additions
and
47 deletions
+84
-47
openmmapi/include/openmm/internal/CustomNonbondedForceImpl.h
openmmapi/include/openmm/internal/CustomNonbondedForceImpl.h
+6
-4
openmmapi/src/CustomNonbondedForceImpl.cpp
openmmapi/src/CustomNonbondedForceImpl.cpp
+72
-37
platforms/common/src/CommonKernels.cpp
platforms/common/src/CommonKernels.cpp
+2
-2
platforms/cpu/src/CpuKernels.cpp
platforms/cpu/src/CpuKernels.cpp
+2
-2
platforms/reference/src/ReferenceKernels.cpp
platforms/reference/src/ReferenceKernels.cpp
+2
-2
No files found.
openmmapi/include/openmm/internal/CustomNonbondedForceImpl.h
View file @
c3c8ec55
...
@@ -37,6 +37,7 @@
...
@@ -37,6 +37,7 @@
#include "openmm/Kernel.h"
#include "openmm/Kernel.h"
#include "openmm/internal/ThreadPool.h"
#include "openmm/internal/ThreadPool.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CompiledVectorExpression.h"
#include <utility>
#include <utility>
#include <map>
#include <map>
#include <string>
#include <string>
...
@@ -69,7 +70,7 @@ public:
...
@@ -69,7 +70,7 @@ public:
* the Context (such as global parameters). This allows the coefficient to be updated
* the Context (such as global parameters). This allows the coefficient to be updated
* more quickly when global parameters change.
* more quickly when global parameters change.
*/
*/
static
LongRangeCorrectionData
prepareLongRangeCorrection
(
const
CustomNonbondedForce
&
force
);
static
LongRangeCorrectionData
prepareLongRangeCorrection
(
const
CustomNonbondedForce
&
force
,
int
numThreads
);
/**
/**
* Compute the coefficient which, when divided by the periodic box volume, gives the
* Compute the coefficient which, when divided by the periodic box volume, gives the
* long range correction to the energy. If the Force computes parameter derivatives,
* long range correction to the energy. If the Force computes parameter derivatives,
...
@@ -77,7 +78,7 @@ public:
...
@@ -77,7 +78,7 @@ public:
*/
*/
static
void
calcLongRangeCorrection
(
const
CustomNonbondedForce
&
force
,
LongRangeCorrectionData
&
data
,
const
Context
&
context
,
double
&
coefficient
,
std
::
vector
<
double
>&
derivatives
,
ThreadPool
&
threads
);
static
void
calcLongRangeCorrection
(
const
CustomNonbondedForce
&
force
,
LongRangeCorrectionData
&
data
,
const
Context
&
context
,
double
&
coefficient
,
std
::
vector
<
double
>&
derivatives
,
ThreadPool
&
threads
);
private:
private:
static
double
integrateInteraction
(
Lepton
::
CompiledExpression
&
expression
,
const
std
::
vector
<
double
>&
params1
,
const
std
::
vector
<
double
>&
params2
,
static
double
integrateInteraction
(
Lepton
::
Compiled
Vector
Expression
&
expression
,
const
std
::
vector
<
double
>&
params1
,
const
std
::
vector
<
double
>&
params2
,
const
std
::
vector
<
double
>&
computedValues1
,
const
std
::
vector
<
double
>&
computedValues2
,
const
CustomNonbondedForce
&
force
,
const
Context
&
context
,
const
std
::
vector
<
double
>&
computedValues1
,
const
std
::
vector
<
double
>&
computedValues2
,
const
CustomNonbondedForce
&
force
,
const
Context
&
context
,
const
std
::
vector
<
std
::
string
>&
paramNames
,
const
std
::
vector
<
std
::
string
>&
computedValueNames
);
const
std
::
vector
<
std
::
string
>&
paramNames
,
const
std
::
vector
<
std
::
string
>&
computedValueNames
);
const
CustomNonbondedForce
&
owner
;
const
CustomNonbondedForce
&
owner
;
...
@@ -90,8 +91,9 @@ public:
...
@@ -90,8 +91,9 @@ public:
std
::
vector
<
std
::
vector
<
double
>
>
classes
;
std
::
vector
<
std
::
vector
<
double
>
>
classes
;
std
::
vector
<
std
::
string
>
paramNames
,
computedValueNames
;
std
::
vector
<
std
::
string
>
paramNames
,
computedValueNames
;
std
::
map
<
std
::
pair
<
int
,
int
>
,
long
long
int
>
interactionCount
;
std
::
map
<
std
::
pair
<
int
,
int
>
,
long
long
int
>
interactionCount
;
Lepton
::
CompiledExpression
energyExpression
;
std
::
vector
<
Lepton
::
CompiledVectorExpression
>
energyExpression
;
std
::
vector
<
Lepton
::
CompiledExpression
>
derivExpressions
,
computedValueExpressions
;
std
::
vector
<
std
::
vector
<
Lepton
::
CompiledVectorExpression
>
>
derivExpressions
;
std
::
vector
<
Lepton
::
CompiledExpression
>
computedValueExpressions
;
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
openmmapi/src/CustomNonbondedForceImpl.cpp
View file @
c3c8ec55
...
@@ -162,7 +162,7 @@ void CustomNonbondedForceImpl::updateParametersInContext(ContextImpl& context) {
...
@@ -162,7 +162,7 @@ void CustomNonbondedForceImpl::updateParametersInContext(ContextImpl& context) {
context
.
systemChanged
();
context
.
systemChanged
();
}
}
CustomNonbondedForceImpl
::
LongRangeCorrectionData
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
const
CustomNonbondedForce
&
force
)
{
CustomNonbondedForceImpl
::
LongRangeCorrectionData
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
const
CustomNonbondedForce
&
force
,
int
numThreads
)
{
LongRangeCorrectionData
data
;
LongRangeCorrectionData
data
;
data
.
method
=
force
.
getNonbondedMethod
();
data
.
method
=
force
.
getNonbondedMethod
();
if
(
data
.
method
==
CustomNonbondedForce
::
NoCutoff
||
data
.
method
==
CustomNonbondedForce
::
CutoffNonPeriodic
)
if
(
data
.
method
==
CustomNonbondedForce
::
NoCutoff
||
data
.
method
==
CustomNonbondedForce
::
CutoffNonPeriodic
)
...
@@ -227,12 +227,19 @@ CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prep
...
@@ -227,12 +227,19 @@ CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prep
// Prepare for evaluating the expressions.
// Prepare for evaluating the expressions.
int
width
=
Lepton
::
CompiledVectorExpression
::
getAllowedWidths
().
back
();
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
functions
[
force
.
getTabulatedFunctionName
(
i
)]
=
createReferenceTabulatedFunction
(
force
.
getTabulatedFunction
(
i
));
data
.
energyExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
).
createCompiledExpression
();
Lepton
::
CompiledVectorExpression
energyExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
).
createCompiledVectorExpression
(
width
);
for
(
int
k
=
0
;
k
<
force
.
getNumEnergyParameterDerivatives
();
k
++
)
for
(
int
i
=
0
;
i
<
numThreads
;
i
++
)
data
.
derivExpressions
.
push_back
(
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
).
differentiate
(
force
.
getEnergyParameterDerivativeName
(
k
)).
createCompiledExpression
());
data
.
energyExpression
.
push_back
(
energyExpression
);
data
.
derivExpressions
.
resize
(
numThreads
);
for
(
int
k
=
0
;
k
<
force
.
getNumEnergyParameterDerivatives
();
k
++
)
{
Lepton
::
CompiledVectorExpression
derivExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
(),
functions
).
differentiate
(
force
.
getEnergyParameterDerivativeName
(
k
)).
createCompiledVectorExpression
(
width
);
for
(
int
i
=
0
;
i
<
numThreads
;
i
++
)
data
.
derivExpressions
[
i
].
push_back
(
derivExpression
);
}
for
(
int
i
=
0
;
i
<
force
.
getNumPerParticleParameters
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumPerParticleParameters
();
i
++
)
{
string
name
=
force
.
getPerParticleParameterName
(
i
);
string
name
=
force
.
getPerParticleParameterName
(
i
);
data
.
paramNames
.
push_back
(
name
+
"1"
);
data
.
paramNames
.
push_back
(
name
+
"1"
);
...
@@ -283,7 +290,7 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
...
@@ -283,7 +290,7 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
vector
<
double
>
threadSum
(
threads
.
getNumThreads
(),
0.0
);
vector
<
double
>
threadSum
(
threads
.
getNumThreads
(),
0.0
);
atomic
<
int
>
atomicCounter
(
0
);
atomic
<
int
>
atomicCounter
(
0
);
threads
.
execute
([
&
]
(
ThreadPool
&
threads
,
int
threadIndex
)
{
threads
.
execute
([
&
]
(
ThreadPool
&
threads
,
int
threadIndex
)
{
Lepton
::
CompiledExpression
expression
=
data
.
energyExpression
;
Lepton
::
Compiled
Vector
Expression
&
expression
=
data
.
energyExpression
[
threadIndex
]
;
while
(
true
)
{
while
(
true
)
{
int
i
=
atomicCounter
++
;
int
i
=
atomicCounter
++
;
if
(
i
>=
numClasses
)
if
(
i
>=
numClasses
)
...
@@ -302,13 +309,13 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
...
@@ -302,13 +309,13 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
// Now do the same for parameter derivatives.
// Now do the same for parameter derivatives.
int
numDerivs
=
data
.
derivExpressions
.
size
();
int
numDerivs
=
data
.
derivExpressions
[
0
]
.
size
();
derivatives
.
resize
(
numDerivs
);
derivatives
.
resize
(
numDerivs
);
for
(
int
k
=
0
;
k
<
numDerivs
;
k
++
)
{
for
(
int
k
=
0
;
k
<
numDerivs
;
k
++
)
{
atomicCounter
=
0
;
atomicCounter
=
0
;
threads
.
execute
([
&
]
(
ThreadPool
&
threads
,
int
threadIndex
)
{
threads
.
execute
([
&
]
(
ThreadPool
&
threads
,
int
threadIndex
)
{
threadSum
[
threadIndex
]
=
0
;
threadSum
[
threadIndex
]
=
0
;
Lepton
::
CompiledExpression
expression
=
data
.
derivExpressions
[
k
];
Lepton
::
Compiled
Vector
Expression
&
expression
=
data
.
derivExpressions
[
threadIndex
][
k
];
while
(
true
)
{
while
(
true
)
{
int
i
=
atomicCounter
++
;
int
i
=
atomicCounter
++
;
if
(
i
>=
numClasses
)
if
(
i
>=
numClasses
)
...
@@ -327,35 +334,51 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
...
@@ -327,35 +334,51 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
}
}
}
}
double
CustomNonbondedForceImpl
::
integrateInteraction
(
Lepton
::
CompiledExpression
&
expression
,
const
vector
<
double
>&
params1
,
const
vector
<
double
>&
params2
,
double
CustomNonbondedForceImpl
::
integrateInteraction
(
Lepton
::
Compiled
Vector
Expression
&
expression
,
const
vector
<
double
>&
params1
,
const
vector
<
double
>&
params2
,
const
vector
<
double
>&
computedValues1
,
const
vector
<
double
>&
computedValues2
,
const
CustomNonbondedForce
&
force
,
const
Context
&
context
,
const
vector
<
double
>&
computedValues1
,
const
vector
<
double
>&
computedValues2
,
const
CustomNonbondedForce
&
force
,
const
Context
&
context
,
const
vector
<
string
>&
paramNames
,
const
vector
<
string
>&
computedValueNames
)
{
const
vector
<
string
>&
paramNames
,
const
vector
<
string
>&
computedValueNames
)
{
int
width
=
expression
.
getWidth
();
const
set
<
string
>&
variables
=
expression
.
getVariables
();
const
set
<
string
>&
variables
=
expression
.
getVariables
();
for
(
int
i
=
0
;
i
<
force
.
getNumPerParticleParameters
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumPerParticleParameters
();
i
++
)
{
if
(
variables
.
find
(
paramNames
[
2
*
i
])
!=
variables
.
end
())
if
(
variables
.
find
(
paramNames
[
2
*
i
])
!=
variables
.
end
())
{
expression
.
getVariableReference
(
paramNames
[
2
*
i
])
=
params1
[
i
];
float
*
pointer
=
expression
.
getVariablePointer
(
paramNames
[
2
*
i
]);
if
(
variables
.
find
(
paramNames
[
2
*
i
+
1
])
!=
variables
.
end
())
for
(
int
j
=
0
;
j
<
width
;
j
++
)
expression
.
getVariableReference
(
paramNames
[
2
*
i
+
1
])
=
params2
[
i
];
pointer
[
j
]
=
params1
[
i
];
}
if
(
variables
.
find
(
paramNames
[
2
*
i
+
1
])
!=
variables
.
end
())
{
float
*
pointer
=
expression
.
getVariablePointer
(
paramNames
[
2
*
i
+
1
]);
for
(
int
j
=
0
;
j
<
width
;
j
++
)
pointer
[
j
]
=
params2
[
i
];
}
}
}
for
(
int
i
=
0
;
i
<
force
.
getNumComputedValues
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumComputedValues
();
i
++
)
{
if
(
variables
.
find
(
computedValueNames
[
2
*
i
])
!=
variables
.
end
())
if
(
variables
.
find
(
computedValueNames
[
2
*
i
])
!=
variables
.
end
())
{
expression
.
getVariableReference
(
computedValueNames
[
2
*
i
])
=
computedValues1
[
i
];
float
*
pointer
=
expression
.
getVariablePointer
(
computedValueNames
[
2
*
i
]);
if
(
variables
.
find
(
computedValueNames
[
2
*
i
+
1
])
!=
variables
.
end
())
for
(
int
j
=
0
;
j
<
width
;
j
++
)
expression
.
getVariableReference
(
computedValueNames
[
2
*
i
+
1
])
=
computedValues2
[
i
];
pointer
[
j
]
=
computedValues1
[
i
];
}
if
(
variables
.
find
(
computedValueNames
[
2
*
i
+
1
])
!=
variables
.
end
())
{
float
*
pointer
=
expression
.
getVariablePointer
(
computedValueNames
[
2
*
i
+
1
]);
for
(
int
j
=
0
;
j
<
width
;
j
++
)
pointer
[
j
]
=
computedValues2
[
i
];
}
}
}
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumGlobalParameters
();
i
++
)
{
const
string
&
name
=
force
.
getGlobalParameterName
(
i
);
const
string
&
name
=
force
.
getGlobalParameterName
(
i
);
if
(
variables
.
find
(
name
)
!=
variables
.
end
())
if
(
variables
.
find
(
name
)
!=
variables
.
end
())
{
expression
.
getVariableReference
(
name
)
=
context
.
getParameter
(
name
);
float
*
pointer
=
expression
.
getVariablePointer
(
name
);
for
(
int
j
=
0
;
j
<
width
;
j
++
)
pointer
[
j
]
=
context
.
getParameter
(
name
);
}
}
}
// To integrate from r_cutoff to infinity, make the change of variables x=r_cutoff/r and integrate from 0 to 1.
// To integrate from r_cutoff to infinity, make the change of variables x=r_cutoff/r and integrate from 0 to 1.
// This introduces another r^2 into the integral, which along with the r^2 in the formula for the correction
// This introduces another r^2 into the integral, which along with the r^2 in the formula for the correction
// means we multiply the function by r^4. Use the midpoint method.
// means we multiply the function by r^4. Use the midpoint method.
double
*
rPointe
r
;
float
*
r
;
try
{
try
{
r
Pointer
=
&
expression
.
getVariable
Reference
(
"r"
);
r
=
expression
.
getVariable
Pointer
(
"r"
);
}
}
catch
(
exception
&
ex
)
{
catch
(
exception
&
ex
)
{
throw
OpenMMException
(
"CustomNonbondedForce: Cannot use long range correction with a force that does not depend on r."
);
throw
OpenMMException
(
"CustomNonbondedForce: Cannot use long range correction with a force that does not depend on r."
);
...
@@ -366,14 +389,20 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
...
@@ -366,14 +389,20 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
for
(
int
iteration
=
0
;
;
iteration
++
)
{
for
(
int
iteration
=
0
;
;
iteration
++
)
{
double
oldSum
=
sum
;
double
oldSum
=
sum
;
double
newSum
=
0
;
double
newSum
=
0
;
int
element
=
0
;
for
(
int
i
=
0
;
i
<
numPoints
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numPoints
;
i
++
)
{
if
(
i
%
3
==
1
)
if
(
i
%
3
!=
1
)
{
continue
;
double
x
=
(
i
+
0.5
)
/
numPoints
;
double
x
=
(
i
+
0.5
)
/
numPoints
;
r
[
element
++
]
=
cutoff
/
x
;
double
r
=
cutoff
/
x
;
if
(
element
==
width
||
i
==
numPoints
-
1
)
{
*
rPointer
=
r
;
const
float
*
result
=
expression
.
evaluate
();
double
r2
=
r
*
r
;
for
(
int
j
=
0
;
j
<
element
;
j
++
)
{
newSum
+=
expression
.
evaluate
()
*
r2
*
r2
;
float
r2
=
r
[
j
]
*
r
[
j
];
newSum
+=
result
[
j
]
*
r2
*
r2
;
}
element
=
0
;
}
}
}
}
sum
=
newSum
/
numPoints
+
oldSum
/
3
;
sum
=
newSum
/
numPoints
+
oldSum
/
3
;
double
relativeChange
=
fabs
((
sum
-
oldSum
)
/
sum
);
double
relativeChange
=
fabs
((
sum
-
oldSum
)
/
sum
);
...
@@ -383,25 +412,31 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
...
@@ -383,25 +412,31 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
throw
OpenMMException
(
"CustomNonbondedForce: Long range correction did not converge. Does the energy go to 0 faster than 1/r^2?"
);
throw
OpenMMException
(
"CustomNonbondedForce: Long range correction did not converge. Does the energy go to 0 faster than 1/r^2?"
);
numPoints
*=
3
;
numPoints
*=
3
;
}
}
// If a switching function is used, integrate over the switching interval.
// If a switching function is used, integrate over the switching interval.
double
sum2
=
0
;
double
sum2
=
0
;
if
(
force
.
getUseSwitchingFunction
())
{
if
(
force
.
getUseSwitchingFunction
())
{
double
rswitch
=
force
.
getSwitchingDistance
();
double
rswitch
=
force
.
getSwitchingDistance
();
sum2
=
0
;
sum2
=
0
;
numPoints
=
1
;
numPoints
=
1
;
vector
<
double
>
switchValue
(
width
);
for
(
int
iteration
=
0
;
;
iteration
++
)
{
for
(
int
iteration
=
0
;
;
iteration
++
)
{
double
oldSum
=
sum2
;
double
oldSum
=
sum2
;
double
newSum
=
0
;
double
newSum
=
0
;
int
element
=
0
;
for
(
int
i
=
0
;
i
<
numPoints
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numPoints
;
i
++
)
{
if
(
i
%
3
==
1
)
if
(
i
%
3
!=
1
)
{
continue
;
double
x
=
(
i
+
0.5
)
/
numPoints
;
double
x
=
(
i
+
0.5
)
/
numPoints
;
switchValue
[
element
]
=
x
*
x
*
x
*
(
10
+
x
*
(
-
15
+
x
*
6
));
double
r
=
rswitch
+
x
*
(
cutoff
-
rswitch
);
r
[
element
++
]
=
rswitch
+
x
*
(
cutoff
-
rswitch
);
double
switchValue
=
x
*
x
*
x
*
(
10
+
x
*
(
-
15
+
x
*
6
));
if
(
element
==
width
||
i
==
numPoints
-
1
)
{
*
rPointer
=
r
;
const
float
*
result
=
expression
.
evaluate
();
newSum
+=
switchValue
*
expression
.
evaluate
()
*
r
*
r
;
for
(
int
j
=
0
;
j
<
element
;
j
++
)
newSum
+=
switchValue
[
j
]
*
result
[
j
]
*
r
[
j
]
*
r
[
j
];
element
=
0
;
}
}
}
}
sum2
=
newSum
/
numPoints
+
oldSum
/
3
;
sum2
=
newSum
/
numPoints
+
oldSum
/
3
;
double
relativeChange
=
fabs
((
sum2
-
oldSum
)
/
sum2
);
double
relativeChange
=
fabs
((
sum2
-
oldSum
)
/
sum2
);
...
...
platforms/common/src/CommonKernels.cpp
View file @
c3c8ec55
...
@@ -2073,7 +2073,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -2073,7 +2073,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
if
(
force
.
getNonbondedMethod
()
==
CustomNonbondedForce
::
CutoffPeriodic
&&
force
.
getUseLongRangeCorrection
()
&&
cc
.
getContextIndex
()
==
0
)
{
if
(
force
.
getNonbondedMethod
()
==
CustomNonbondedForce
::
CutoffPeriodic
&&
force
.
getUseLongRangeCorrection
()
&&
cc
.
getContextIndex
()
==
0
)
{
forceCopy
=
new
CustomNonbondedForce
(
force
);
forceCopy
=
new
CustomNonbondedForce
(
force
);
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
);
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
,
cc
.
getThreadPool
().
getNumThreads
()
);
cc
.
addPostComputation
(
new
LongRangePostComputation
(
cc
,
longRangeCoefficient
,
longRangeCoefficientDerivs
,
forceCopy
));
cc
.
addPostComputation
(
new
LongRangePostComputation
(
cc
,
longRangeCoefficient
,
longRangeCoefficientDerivs
,
forceCopy
));
hasInitializedLongRangeCorrection
=
false
;
hasInitializedLongRangeCorrection
=
false
;
}
}
...
@@ -2449,7 +2449,7 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
...
@@ -2449,7 +2449,7 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
// If necessary, recompute the long range correction.
// If necessary, recompute the long range correction.
if
(
forceCopy
!=
NULL
)
{
if
(
forceCopy
!=
NULL
)
{
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
);
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
,
cc
.
getThreadPool
().
getNumThreads
()
);
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
force
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
cc
.
getThreadPool
());
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
force
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
cc
.
getThreadPool
());
hasInitializedLongRangeCorrection
=
false
;
hasInitializedLongRangeCorrection
=
false
;
*
forceCopy
=
force
;
*
forceCopy
=
force
;
...
...
platforms/cpu/src/CpuKernels.cpp
View file @
c3c8ec55
...
@@ -1011,7 +1011,7 @@ double CpuCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool inc
...
@@ -1011,7 +1011,7 @@ double CpuCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool inc
// Add in the long range correction.
// Add in the long range correction.
if
(
!
hasInitializedLongRangeCorrection
)
{
if
(
!
hasInitializedLongRangeCorrection
)
{
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
*
forceCopy
);
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
*
forceCopy
,
data
.
threads
.
getNumThreads
()
);
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
*
forceCopy
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
data
.
threads
);
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
*
forceCopy
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
data
.
threads
);
hasInitializedLongRangeCorrection
=
true
;
hasInitializedLongRangeCorrection
=
true
;
}
}
...
@@ -1042,7 +1042,7 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
...
@@ -1042,7 +1042,7 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
// If necessary, recompute the long range correction.
// If necessary, recompute the long range correction.
if
(
forceCopy
!=
NULL
)
{
if
(
forceCopy
!=
NULL
)
{
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
);
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
,
data
.
threads
.
getNumThreads
()
);
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
force
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
data
.
threads
);
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
force
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
data
.
threads
);
hasInitializedLongRangeCorrection
=
true
;
hasInitializedLongRangeCorrection
=
true
;
*
forceCopy
=
force
;
*
forceCopy
=
force
;
...
...
platforms/reference/src/ReferenceKernels.cpp
View file @
c3c8ec55
...
@@ -1303,8 +1303,8 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo
...
@@ -1303,8 +1303,8 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo
// Add in the long range correction.
// Add in the long range correction.
if
(
!
hasInitializedLongRangeCorrection
)
{
if
(
!
hasInitializedLongRangeCorrection
)
{
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
*
forceCopy
);
ThreadPool
threads
;
ThreadPool
threads
;
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
*
forceCopy
,
threads
.
getNumThreads
());
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
*
forceCopy
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
threads
);
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
*
forceCopy
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
threads
);
hasInitializedLongRangeCorrection
=
true
;
hasInitializedLongRangeCorrection
=
true
;
}
}
...
@@ -1337,8 +1337,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
...
@@ -1337,8 +1337,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
// If necessary, recompute the long range correction.
// If necessary, recompute the long range correction.
if
(
forceCopy
!=
NULL
)
{
if
(
forceCopy
!=
NULL
)
{
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
);
ThreadPool
threads
;
ThreadPool
threads
;
longRangeCorrectionData
=
CustomNonbondedForceImpl
::
prepareLongRangeCorrection
(
force
,
threads
.
getNumThreads
());
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
force
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
threads
);
CustomNonbondedForceImpl
::
calcLongRangeCorrection
(
force
,
longRangeCorrectionData
,
context
.
getOwner
(),
longRangeCoefficient
,
longRangeCoefficientDerivs
,
threads
);
hasInitializedLongRangeCorrection
=
true
;
hasInitializedLongRangeCorrection
=
true
;
*
forceCopy
=
force
;
*
forceCopy
=
force
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment