Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
eaef52d9
Commit
eaef52d9
authored
Sep 18, 2015
by
peastman
Browse files
Created OpenCL implementation of periodicdistance()
parent
91a8cc49
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
282 additions
and
173 deletions
+282
-173
platforms/opencl/include/OpenCLExpressionUtilities.h
platforms/opencl/include/OpenCLExpressionUtilities.h
+7
-3
platforms/opencl/src/OpenCLBondedUtilities.cpp
platforms/opencl/src/OpenCLBondedUtilities.cpp
+18
-3
platforms/opencl/src/OpenCLExpressionUtilities.cpp
platforms/opencl/src/OpenCLExpressionUtilities.cpp
+211
-165
platforms/opencl/src/OpenCLKernels.cpp
platforms/opencl/src/OpenCLKernels.cpp
+3
-1
platforms/opencl/tests/TestOpenCLCustomExternalForce.cpp
platforms/opencl/tests/TestOpenCLCustomExternalForce.cpp
+43
-1
No files found.
platforms/opencl/include/OpenCLExpressionUtilities.h
View file @
eaef52d9
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2009-201
4
Stanford University and the Authors. *
* Portions copyright (c) 2009-201
5
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -89,6 +89,10 @@ public:
...
@@ -89,6 +89,10 @@ public:
* @param function the function for which to get a placeholder
* @param function the function for which to get a placeholder
*/
*/
Lepton
::
CustomFunction
*
getFunctionPlaceholder
(
const
TabulatedFunction
&
function
);
Lepton
::
CustomFunction
*
getFunctionPlaceholder
(
const
TabulatedFunction
&
function
);
/**
* Get a Lepton::CustomFunction that can be used to represent the periodicdistance() function when parsing expressions.
*/
Lepton
::
CustomFunction
*
getPeriodicDistancePlaceholder
();
private:
private:
class
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
class
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
public:
public:
...
@@ -114,13 +118,13 @@ private:
...
@@ -114,13 +118,13 @@ private:
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functionNames
,
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functionNames
,
const
std
::
string
&
prefix
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
,
const
std
::
string
&
tempType
);
const
std
::
string
&
prefix
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
,
const
std
::
string
&
tempType
);
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
void
findRelated
Tabulated
Functions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelated
Custom
Functions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
std
::
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
);
std
::
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
);
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
std
::
vector
<
std
::
vector
<
double
>
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
std
::
vector
<
std
::
vector
<
double
>
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
OpenCLContext
&
context
;
OpenCLContext
&
context
;
FunctionPlaceholder
fp1
,
fp2
,
fp3
;
FunctionPlaceholder
fp1
,
fp2
,
fp3
,
periodicDistance
;
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
platforms/opencl/src/OpenCLBondedUtilities.cpp
View file @
eaef52d9
...
@@ -181,7 +181,7 @@ void OpenCLBondedUtilities::initialize(const System& system) {
...
@@ -181,7 +181,7 @@ void OpenCLBondedUtilities::initialize(const System& system) {
for
(
int
i
=
0
;
i
<
(
int
)
prefixCode
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
prefixCode
.
size
();
i
++
)
s
<<
prefixCode
[
i
];
s
<<
prefixCode
[
i
];
string
bufferType
=
(
context
.
getSupports64BitGlobalAtomics
()
?
"long"
:
"real4"
);
string
bufferType
=
(
context
.
getSupports64BitGlobalAtomics
()
?
"long"
:
"real4"
);
s
<<
"__kernel void computeBondedForces(__global "
<<
bufferType
<<
"* restrict forceBuffers, __global real* restrict energyBuffer, __global const real4* restrict posq, int groups"
;
s
<<
"__kernel void computeBondedForces(__global "
<<
bufferType
<<
"* restrict forceBuffers, __global real* restrict energyBuffer, __global const real4* restrict posq, int groups
, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
"
;
for
(
int
i
=
0
;
i
<
setSize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
setSize
;
i
++
)
{
int
force
=
set
[
i
];
int
force
=
set
[
i
];
string
indexType
=
"uint"
+
(
indexWidth
[
force
]
==
1
?
""
:
context
.
intToString
(
indexWidth
[
force
]));
string
indexType
=
"uint"
+
(
indexWidth
[
force
]
==
1
?
""
:
context
.
intToString
(
indexWidth
[
force
]));
...
@@ -267,7 +267,7 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
...
@@ -267,7 +267,7 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
context
.
getForceBuffers
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
context
.
getForceBuffers
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
context
.
getEnergyBuffer
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
context
.
getEnergyBuffer
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
context
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
context
.
getPosq
().
getDeviceBuffer
());
index
++
;
index
+=
6
;
for
(
int
j
=
0
;
j
<
(
int
)
forceSets
[
i
].
size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
(
int
)
forceSets
[
i
].
size
();
j
++
)
{
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
atomIndices
[
forceSets
[
i
][
j
]]
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
atomIndices
[
forceSets
[
i
][
j
]]
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
bufferIndices
[
forceSets
[
i
][
j
]]
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
index
++
,
bufferIndices
[
forceSets
[
i
][
j
]]
->
getDeviceBuffer
());
...
@@ -277,7 +277,22 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
...
@@ -277,7 +277,22 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
}
}
}
}
for
(
int
i
=
0
;
i
<
(
int
)
kernels
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
(
int
)
kernels
.
size
();
i
++
)
{
kernels
[
i
].
setArg
<
cl_int
>
(
3
,
groups
);
cl
::
Kernel
&
kernel
=
kernels
[
i
];
kernel
.
setArg
<
cl_int
>
(
3
,
groups
);
if
(
context
.
getUseDoublePrecision
())
{
kernel
.
setArg
<
mm_double4
>
(
4
,
context
.
getPeriodicBoxSizeDouble
());
kernel
.
setArg
<
mm_double4
>
(
5
,
context
.
getInvPeriodicBoxSizeDouble
());
kernel
.
setArg
<
mm_double4
>
(
6
,
context
.
getPeriodicBoxVecXDouble
());
kernel
.
setArg
<
mm_double4
>
(
7
,
context
.
getPeriodicBoxVecYDouble
());
kernel
.
setArg
<
mm_double4
>
(
8
,
context
.
getPeriodicBoxVecZDouble
());
}
else
{
kernel
.
setArg
<
mm_float4
>
(
4
,
context
.
getPeriodicBoxSize
());
kernel
.
setArg
<
mm_float4
>
(
5
,
context
.
getInvPeriodicBoxSize
());
kernel
.
setArg
<
mm_float4
>
(
6
,
context
.
getPeriodicBoxVecX
());
kernel
.
setArg
<
mm_float4
>
(
7
,
context
.
getPeriodicBoxVecY
());
kernel
.
setArg
<
mm_float4
>
(
8
,
context
.
getPeriodicBoxVecZ
());
}
context
.
executeKernel
(
kernels
[
i
],
maxBonds
);
context
.
executeKernel
(
kernels
[
i
],
maxBonds
);
}
}
}
}
platforms/opencl/src/OpenCLExpressionUtilities.cpp
View file @
eaef52d9
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2009-201
4
Stanford University and the Authors. *
* Portions copyright (c) 2009-201
5
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -33,7 +33,7 @@ using namespace OpenMM;
...
@@ -33,7 +33,7 @@ using namespace OpenMM;
using
namespace
Lepton
;
using
namespace
Lepton
;
using
namespace
std
;
using
namespace
std
;
OpenCLExpressionUtilities
::
OpenCLExpressionUtilities
(
OpenCLContext
&
context
)
:
context
(
context
),
fp1
(
1
),
fp2
(
2
),
fp3
(
3
)
{
OpenCLExpressionUtilities
::
OpenCLExpressionUtilities
(
OpenCLContext
&
context
)
:
context
(
context
),
fp1
(
1
),
fp2
(
2
),
fp3
(
3
)
,
periodicDistance
(
6
)
{
}
}
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
...
@@ -79,11 +79,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -79,11 +79,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
throw
OpenMMException
(
"Unknown variable in expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Unknown variable in expression: "
+
node
.
getOperation
().
getName
());
case
Operation
::
CUSTOM
:
case
Operation
::
CUSTOM
:
{
{
int
i
;
for
(
i
=
0
;
i
<
(
int
)
functionNames
.
size
()
&&
functionNames
[
i
].
first
!=
node
.
getOperation
().
getName
();
i
++
)
;
if
(
i
==
functionNames
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
out
<<
"0.0f;
\n
"
;
out
<<
"0.0f;
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
temps
.
push_back
(
make_pair
(
node
,
name
));
hasRecordedNode
=
true
;
hasRecordedNode
=
true
;
...
@@ -93,7 +88,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -93,7 +88,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
vector
<
const
ExpressionTreeNode
*>
nodes
;
vector
<
const
ExpressionTreeNode
*>
nodes
;
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
findRelated
Tabulated
Functions
(
node
,
allExpressions
[
j
].
getRootNode
(),
nodes
);
findRelated
Custom
Functions
(
node
,
allExpressions
[
j
].
getRootNode
(),
nodes
);
vector
<
string
>
nodeNames
;
vector
<
string
>
nodeNames
;
nodeNames
.
push_back
(
name
);
nodeNames
.
push_back
(
name
);
for
(
int
j
=
1
;
j
<
(
int
)
nodes
.
size
();
j
++
)
{
for
(
int
j
=
1
;
j
<
(
int
)
nodes
.
size
();
j
++
)
{
...
@@ -103,6 +98,52 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -103,6 +98,52 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
temps
.
push_back
(
make_pair
(
*
nodes
[
j
],
name2
));
temps
.
push_back
(
make_pair
(
*
nodes
[
j
],
name2
));
}
}
out
<<
"{
\n
"
;
out
<<
"{
\n
"
;
if
(
node
.
getOperation
().
getName
()
==
"periodicdistance"
)
{
// This is the periodicdistance() function.
out
<<
tempType
<<
"3 periodicDistance_delta = (real3) ("
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
if
(
i
>
0
)
out
<<
", "
;
out
<<
getTempName
(
node
.
getChildren
()[
i
],
temps
)
<<
"-"
<<
getTempName
(
node
.
getChildren
()[
i
+
3
],
temps
);
}
out
<<
");
\n
"
;
out
<<
"APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)
\n
"
;
out
<<
tempType
<<
" periodicDistance_rinv = RSQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);
\n
"
;
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
int
argIndex
=
-
1
;
for
(
int
k
=
0
;
k
<
6
;
k
++
)
{
if
(
derivOrder
[
k
]
>
0
)
{
if
(
derivOrder
[
k
]
>
1
||
argIndex
!=
-
1
)
throw
OpenMMException
(
"Unsupported derivative of periodicdistance"
);
// Should be impossible for this to happen.
argIndex
=
k
;
}
}
if
(
argIndex
==
-
1
)
out
<<
nodeNames
[
j
]
<<
" = RECIP(periodicDistance_rinv);
\n
"
;
else
if
(
argIndex
==
0
)
out
<<
nodeNames
[
j
]
<<
" = periodicDistance_delta.x*periodicDistance_rinv;
\n
"
;
else
if
(
argIndex
==
1
)
out
<<
nodeNames
[
j
]
<<
" = periodicDistance_delta.y*periodicDistance_rinv;
\n
"
;
else
if
(
argIndex
==
2
)
out
<<
nodeNames
[
j
]
<<
" = periodicDistance_delta.z*periodicDistance_rinv;
\n
"
;
else
if
(
argIndex
==
3
)
out
<<
nodeNames
[
j
]
<<
" = -periodicDistance_delta.x*periodicDistance_rinv;
\n
"
;
else
if
(
argIndex
==
4
)
out
<<
nodeNames
[
j
]
<<
" = -periodicDistance_delta.y*periodicDistance_rinv;
\n
"
;
else
if
(
argIndex
==
5
)
out
<<
nodeNames
[
j
]
<<
" = -periodicDistance_delta.z*periodicDistance_rinv;
\n
"
;
}
}
else
{
// This is a tabulated function.
int
i
;
for
(
i
=
0
;
i
<
(
int
)
functionNames
.
size
()
&&
functionNames
[
i
].
first
!=
node
.
getOperation
().
getName
();
i
++
)
;
if
(
i
==
functionNames
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
vector
<
string
>
paramsFloat
,
paramsInt
;
vector
<
string
>
paramsFloat
,
paramsInt
;
for
(
int
j
=
0
;
j
<
(
int
)
functionParams
[
i
].
size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
(
int
)
functionParams
[
i
].
size
();
j
++
)
{
paramsFloat
.
push_back
(
context
.
doubleToString
(
functionParams
[
i
][
j
]));
paramsFloat
.
push_back
(
context
.
doubleToString
(
functionParams
[
i
][
j
]));
...
@@ -275,6 +316,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -275,6 +316,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
}
}
}
}
}
}
}
out
<<
"}"
;
out
<<
"}"
;
break
;
break
;
}
}
...
@@ -475,7 +517,7 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
...
@@ -475,7 +517,7 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
throw
OpenMMException
(
out
.
str
());
throw
OpenMMException
(
out
.
str
());
}
}
void
OpenCLExpressionUtilities
::
findRelated
Tabulated
Functions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
void
OpenCLExpressionUtilities
::
findRelated
Custom
Functions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
)
{
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
)
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getOperation
().
getName
()
==
searchNode
.
getOperation
().
getName
())
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getOperation
().
getName
()
==
searchNode
.
getOperation
().
getName
())
{
// Make sure the arguments are identical.
// Make sure the arguments are identical.
...
@@ -496,7 +538,7 @@ void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTr
...
@@ -496,7 +538,7 @@ void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTr
}
}
else
else
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
findRelated
Tabulated
Functions
(
node
,
searchNode
.
getChildren
()[
i
],
nodes
);
findRelated
Custom
Functions
(
node
,
searchNode
.
getChildren
()[
i
],
nodes
);
}
}
void
OpenCLExpressionUtilities
::
findRelatedPowers
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
map
<
int
,
const
ExpressionTreeNode
*>&
powers
)
{
void
OpenCLExpressionUtilities
::
findRelatedPowers
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
map
<
int
,
const
ExpressionTreeNode
*>&
powers
)
{
...
@@ -722,3 +764,7 @@ Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const
...
@@ -722,3 +764,7 @@ Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const
return
&
fp3
;
return
&
fp3
;
throw
OpenMMException
(
"getFunctionPlaceholder: Unknown function type"
);
throw
OpenMMException
(
"getFunctionPlaceholder: Unknown function type"
);
}
}
Lepton
::
CustomFunction
*
OpenCLExpressionUtilities
::
getPeriodicDistancePlaceholder
()
{
return
&
periodicDistance
;
}
platforms/opencl/src/OpenCLKernels.cpp
View file @
eaef52d9
...
@@ -3821,7 +3821,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const
...
@@ -3821,7 +3821,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
}
}
Lepton
::
ParsedExpression
energyExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
()).
optimize
();
map<string, Lepton::CustomFunction*> customFunctions;
customFunctions["periodicdistance"] = cl.getExpressionUtilities().getPeriodicDistancePlaceholder();
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize();
Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
...
...
platforms/opencl/tests/TestOpenCLCustomExternalForce.cpp
View file @
eaef52d9
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* *
* Portions copyright (c) 2008-20
09
Stanford University and the Authors. *
* Portions copyright (c) 2008-20
15
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: *
* *
* *
...
@@ -161,6 +161,47 @@ void testParallelComputation() {
...
@@ -161,6 +161,47 @@ void testParallelComputation() {
ASSERT_EQUAL_VEC
(
state1
.
getForces
()[
i
],
state2
.
getForces
()[
i
],
1e-5
);
ASSERT_EQUAL_VEC
(
state1
.
getForces
()[
i
],
state2
.
getForces
()[
i
],
1e-5
);
}
}
void
testPeriodic
()
{
Vec3
vx
(
5
,
0
,
0
);
Vec3
vy
(
0
,
6
,
0
);
Vec3
vz
(
1
,
2
,
7
);
double
x0
=
51
,
y0
=
-
17
,
z0
=
11.2
;
System
system
;
system
.
setDefaultPeriodicBoxVectors
(
vx
,
vy
,
vz
);
system
.
addParticle
(
1.0
);
CustomExternalForce
*
force
=
new
CustomExternalForce
(
"periodicdistance(x, y, z, x0, y0, z0)^2"
);
force
->
addPerParticleParameter
(
"x0"
);
force
->
addPerParticleParameter
(
"y0"
);
force
->
addPerParticleParameter
(
"z0"
);
vector
<
double
>
params
(
3
);
params
[
0
]
=
x0
;
params
[
1
]
=
y0
;
params
[
2
]
=
z0
;
force
->
addParticle
(
0
,
params
);
system
.
addForce
(
force
);
VerletIntegrator
integrator
(
0.01
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
1
);
positions
[
0
]
=
Vec3
(
0
,
2
,
0
);
context
.
setPositions
(
positions
);
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
State
state
=
context
.
getState
(
State
::
Positions
|
State
::
Forces
|
State
::
Energy
);
// Apply periodic boundary conditions to the difference between the two positions.
Vec3
delta
=
Vec3
(
x0
,
y0
,
z0
)
-
state
.
getPositions
()[
0
];
delta
-=
vz
*
floor
(
delta
[
2
]
/
vz
[
2
]
+
0.5
);
delta
-=
vy
*
floor
(
delta
[
1
]
/
vy
[
1
]
+
0.5
);
delta
-=
vx
*
floor
(
delta
[
0
]
/
vx
[
0
]
+
0.5
);
// Verify that the force and energy are correct.
ASSERT_EQUAL_VEC
(
delta
*
2
,
state
.
getForces
()[
0
],
1e-5
);
ASSERT_EQUAL_TOL
(
delta
.
dot
(
delta
),
state
.
getPotentialEnergy
(),
1e-5
);
integrator
.
step
(
1
);
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
try
{
try
{
if
(
argc
>
1
)
if
(
argc
>
1
)
...
@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) {
...
@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) {
testForce
();
testForce
();
testManyParameters
();
testManyParameters
();
testParallelComputation
();
testParallelComputation
();
testPeriodic
();
}
}
catch
(
const
exception
&
e
)
{
catch
(
const
exception
&
e
)
{
cout
<<
"exception: "
<<
e
.
what
()
<<
endl
;
cout
<<
"exception: "
<<
e
.
what
()
<<
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment