Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
a773952e
Commit
a773952e
authored
Jan 24, 2014
by
peastman
Browse files
Created Discrete2DFunction and Discrete3DFunction
parent
56e36449
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
747 additions
and
157 deletions
+747
-157
openmmapi/include/openmm/TabulatedFunction.h
openmmapi/include/openmm/TabulatedFunction.h
+87
-7
openmmapi/src/TabulatedFunction.cpp
openmmapi/src/TabulatedFunction.cpp
+47
-0
platforms/cuda/include/CudaExpressionUtilities.h
platforms/cuda/include/CudaExpressionUtilities.h
+28
-24
platforms/cuda/src/CudaExpressionUtilities.cpp
platforms/cuda/src/CudaExpressionUtilities.cpp
+117
-34
platforms/cuda/src/CudaKernels.cpp
platforms/cuda/src/CudaKernels.cpp
+4
-8
platforms/cuda/tests/TestCudaCustomNonbondedForce.cpp
platforms/cuda/tests/TestCudaCustomNonbondedForce.cpp
+75
-5
platforms/opencl/include/OpenCLExpressionUtilities.h
platforms/opencl/include/OpenCLExpressionUtilities.h
+28
-24
platforms/opencl/src/OpenCLExpressionUtilities.cpp
platforms/opencl/src/OpenCLExpressionUtilities.cpp
+117
-34
platforms/opencl/src/OpenCLKernels.cpp
platforms/opencl/src/OpenCLKernels.cpp
+4
-8
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
+74
-5
platforms/reference/include/ReferenceTabulatedFunction.h
platforms/reference/include/ReferenceTabulatedFunction.h
+32
-0
platforms/reference/src/ReferenceTabulatedFunction.cpp
platforms/reference/src/ReferenceTabulatedFunction.cpp
+57
-3
platforms/reference/tests/TestReferenceCustomNonbondedForce.cpp
...rms/reference/tests/TestReferenceCustomNonbondedForce.cpp
+77
-5
No files found.
openmmapi/include/openmm/TabulatedFunction.h
View file @
a773952e
...
@@ -102,35 +102,115 @@ private:
...
@@ -102,35 +102,115 @@ private:
};
};
/**
/**
* This is a TabulatedFunction that computes a discrete one dimensional function.
* This is a TabulatedFunction that computes a discrete one dimensional function f(x).
* To evaluate it, x is rounded to the nearest integer and the table element with that
* index is returned. If the index is outside the range [0, size), the result is undefined.
*/
*/
class
OPENMM_EXPORT
Discrete1DFunction
:
public
TabulatedFunction
{
class
OPENMM_EXPORT
Discrete1DFunction
:
public
TabulatedFunction
{
public:
public:
/**
/**
* Create a Discrete1DFunction f(x) based on a set of tabulated values.
* Create a Discrete1DFunction f(x) based on a set of tabulated values.
*
*
* @param values the tabulated values of the function f(x). The function is only defined
* @param values the tabulated values of the function f(x)
* for integer values of x in the range [0, values.size()].
*/
*/
Discrete1DFunction
(
const
std
::
vector
<
double
>&
values
);
Discrete1DFunction
(
const
std
::
vector
<
double
>&
values
);
/**
/**
* Get the parameters for the tabulated function.
* Get the parameters for the tabulated function.
*
*
* @param values the tabulated values of the function f(x). The function is only defined
* @param values the tabulated values of the function f(x)
* for integer values of x in the range [0, values.size()].
*/
*/
void
getFunctionParameters
(
std
::
vector
<
double
>&
values
)
const
;
void
getFunctionParameters
(
std
::
vector
<
double
>&
values
)
const
;
/**
/**
* Set the parameters for the tabulated function.
* Set the parameters for the tabulated function.
*
*
* @param values the tabulated values of the function f(x). The function is only defined
* @param values the tabulated values of the function f(x)
* for integer values of x in the range [0, values.size()].
*/
*/
void
setFunctionParameters
(
const
std
::
vector
<
double
>&
values
);
void
setFunctionParameters
(
const
std
::
vector
<
double
>&
values
);
private:
private:
std
::
vector
<
double
>
values
;
std
::
vector
<
double
>
values
;
};
};
/**
* This is a TabulatedFunction that computes a discrete two dimensional function f(x,y).
* To evaluate it, x and y are each rounded to the nearest integer and the table element with those
* indices is returned. If either index is outside the range [0, size), the result is undefined.
*/
class
OPENMM_EXPORT
Discrete2DFunction
:
public
TabulatedFunction
{
public:
/**
* Create a Discrete2DFunction f(x,y) based on a set of tabulated values.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
Discrete2DFunction
(
int
xsize
,
int
ysize
,
const
std
::
vector
<
double
>&
values
);
/**
* Get the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void
getFunctionParameters
(
int
&
xsize
,
int
&
ysize
,
std
::
vector
<
double
>&
values
)
const
;
/**
* Set the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void
setFunctionParameters
(
int
xsize
,
int
ysize
,
const
std
::
vector
<
double
>&
values
);
private:
int
xsize
,
ysize
;
std
::
vector
<
double
>
values
;
};
/**
* This is a TabulatedFunction that computes a discrete three dimensional function f(x,y,z).
* To evaluate it, x, y, and z are each rounded to the nearest integer and the table element with those
* indices is returned. If any index is outside the range [0, size), the result is undefined.
*/
class
OPENMM_EXPORT
Discrete3DFunction
:
public
TabulatedFunction
{
public:
/**
* Create a Discrete3DFunction f(x,y,z) based on a set of tabulated values.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
Discrete3DFunction
(
int
xsize
,
int
ysize
,
int
zsize
,
const
std
::
vector
<
double
>&
values
);
/**
* Get the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void
getFunctionParameters
(
int
&
xsize
,
int
&
ysize
,
int
&
zsize
,
std
::
vector
<
double
>&
values
)
const
;
/**
* Set the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void
setFunctionParameters
(
int
xsize
,
int
ysize
,
int
zsize
,
const
std
::
vector
<
double
>&
values
);
private:
int
xsize
,
ysize
,
zsize
;
std
::
vector
<
double
>
values
;
};
}
// namespace OpenMM
}
// namespace OpenMM
#endif
/*OPENMM_TABULATEDFUNCTION_H_*/
#endif
/*OPENMM_TABULATEDFUNCTION_H_*/
openmmapi/src/TabulatedFunction.cpp
View file @
a773952e
...
@@ -72,3 +72,50 @@ void Discrete1DFunction::getFunctionParameters(std::vector<double>& values) cons
...
@@ -72,3 +72,50 @@ void Discrete1DFunction::getFunctionParameters(std::vector<double>& values) cons
void
Discrete1DFunction
::
setFunctionParameters
(
const
std
::
vector
<
double
>&
values
)
{
void
Discrete1DFunction
::
setFunctionParameters
(
const
std
::
vector
<
double
>&
values
)
{
this
->
values
=
values
;
this
->
values
=
values
;
}
}
Discrete2DFunction
::
Discrete2DFunction
(
int
xsize
,
int
ysize
,
const
std
::
vector
<
double
>&
values
)
{
if
(
values
.
size
()
!=
xsize
*
ysize
)
throw
OpenMMException
(
"Discrete2DFunction: incorrect number of values"
);
this
->
xsize
=
xsize
;
this
->
ysize
=
ysize
;
this
->
values
=
values
;
}
void
Discrete2DFunction
::
getFunctionParameters
(
int
&
xsize
,
int
&
ysize
,
std
::
vector
<
double
>&
values
)
const
{
xsize
=
this
->
xsize
;
ysize
=
this
->
ysize
;
values
=
this
->
values
;
}
void
Discrete2DFunction
::
setFunctionParameters
(
int
xsize
,
int
ysize
,
const
std
::
vector
<
double
>&
values
)
{
if
(
values
.
size
()
!=
xsize
*
ysize
)
throw
OpenMMException
(
"Discrete2DFunction: incorrect number of values"
);
this
->
xsize
=
xsize
;
this
->
ysize
=
ysize
;
this
->
values
=
values
;
}
Discrete3DFunction
::
Discrete3DFunction
(
int
xsize
,
int
ysize
,
int
zsize
,
const
std
::
vector
<
double
>&
values
)
{
if
(
values
.
size
()
!=
xsize
*
ysize
*
zsize
)
throw
OpenMMException
(
"Discrete3DFunction: incorrect number of values"
);
this
->
xsize
=
xsize
;
this
->
ysize
=
ysize
;
this
->
zsize
=
zsize
;
this
->
values
=
values
;
}
void
Discrete3DFunction
::
getFunctionParameters
(
int
&
xsize
,
int
&
ysize
,
int
&
zsize
,
std
::
vector
<
double
>&
values
)
const
{
xsize
=
this
->
xsize
;
ysize
=
this
->
ysize
;
zsize
=
this
->
zsize
;
values
=
this
->
values
;
}
void
Discrete3DFunction
::
setFunctionParameters
(
int
xsize
,
int
ysize
,
int
zsize
,
const
std
::
vector
<
double
>&
values
)
{
if
(
values
.
size
()
!=
xsize
*
ysize
*
zsize
)
throw
OpenMMException
(
"Discrete3DFunction: incorrect number of values"
);
this
->
xsize
=
xsize
;
this
->
ysize
=
ysize
;
this
->
zsize
=
zsize
;
this
->
values
=
values
;
}
platforms/cuda/include/CudaExpressionUtilities.h
View file @
a773952e
...
@@ -46,8 +46,7 @@ namespace OpenMM {
...
@@ -46,8 +46,7 @@ namespace OpenMM {
class
OPENMM_EXPORT_CUDA
CudaExpressionUtilities
{
class
OPENMM_EXPORT_CUDA
CudaExpressionUtilities
{
public:
public:
CudaExpressionUtilities
(
CudaContext
&
context
)
:
context
(
context
)
{
CudaExpressionUtilities
(
CudaContext
&
context
);
}
/**
/**
* Generate the source code for calculating a set of expressions.
* Generate the source code for calculating a set of expressions.
*
*
...
@@ -93,38 +92,43 @@ public:
...
@@ -93,38 +92,43 @@ public:
* @return the parameter array
* @return the parameter array
*/
*/
std
::
vector
<
float4
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
std
::
vector
<
float4
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
class
FunctionPlaceholder
;
/**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
*
* @param function the function for which to get a placeholder
*/
Lepton
::
CustomFunction
*
getFunctionPlaceholder
(
const
TabulatedFunction
&
function
);
private:
private:
class
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
public:
FunctionPlaceholder
(
int
numArgs
)
:
numArgs
(
numArgs
)
{
}
int
getNumArguments
()
const
{
return
numArgs
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
0.0
;
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
FunctionPlaceholder
(
numArgs
);
}
private:
int
numArgs
;
};
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functionNames
,
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functionNames
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
,
const
std
::
string
&
tempType
);
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
,
const
std
::
string
&
tempType
);
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
void
findRelatedTabulatedFunctions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelatedTabulatedFunctions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
const
Lepton
::
ExpressionTreeNode
*&
valueNode
,
const
Lepton
::
ExpressionTreeNode
*&
derivN
ode
);
std
::
vector
<
const
Lepton
::
ExpressionTreeNode
*
>
&
n
ode
s
);
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
CudaContext
&
context
;
CudaContext
&
context
;
};
FunctionPlaceholder
fp1
,
fp2
,
fp3
;
/**
* This class serves as a placeholder for custom functions in expressions.
*/
class
CudaExpressionUtilities
::
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
public:
int
getNumArguments
()
const
{
return
1
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
0.0
;
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
FunctionPlaceholder
();
}
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
platforms/cuda/src/CudaExpressionUtilities.cpp
View file @
a773952e
...
@@ -33,6 +33,9 @@ using namespace OpenMM;
...
@@ -33,6 +33,9 @@ using namespace OpenMM;
using
namespace
Lepton
;
using
namespace
Lepton
;
using
namespace
std
;
using
namespace
std
;
CudaExpressionUtilities
::
CudaExpressionUtilities
(
CudaContext
&
context
)
:
context
(
context
),
fp1
(
1
),
fp2
(
2
),
fp3
(
3
)
{
}
string
CudaExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
string
CudaExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
const
TabulatedFunction
*>&
functions
,
const
vector
<
pair
<
string
,
string
>
>&
functionNames
,
const
string
&
prefix
,
const
vector
<
const
TabulatedFunction
*>&
functions
,
const
vector
<
pair
<
string
,
string
>
>&
functionNames
,
const
string
&
prefix
,
const
string
&
functionParams
,
const
string
&
tempType
)
{
const
string
&
functionParams
,
const
string
&
tempType
)
{
...
@@ -82,7 +85,6 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
...
@@ -82,7 +85,6 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
;
;
if
(
i
==
functionNames
.
size
())
if
(
i
==
functionNames
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
bool
isDeriv
=
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
1
);
out
<<
"0.0f;
\n
"
;
out
<<
"0.0f;
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
temps
.
push_back
(
make_pair
(
node
,
name
));
hasRecordedNode
=
true
;
hasRecordedNode
=
true
;
...
@@ -90,23 +92,16 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
...
@@ -90,23 +92,16 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
// If both the value and derivative of the function are needed, it's faster to calculate them both
// If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed.
// at once, so check to see if both are needed.
const
ExpressionTreeNode
*
valueNode
=
NULL
;
vector
<
const
ExpressionTreeNode
*>
nodes
;
const
ExpressionTreeNode
*
derivNode
=
NULL
;
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
findRelatedTabulatedFunctions
(
node
,
allExpressions
[
j
].
getRootNode
(),
valueNode
,
derivNode
);
findRelatedTabulatedFunctions
(
node
,
allExpressions
[
j
].
getRootNode
(),
nodes
);
string
valueName
=
name
;
vector
<
string
>
nodeNames
;
string
derivName
=
name
;
nodeNames
.
push_back
(
name
)
;
if
(
valueNode
!=
NULL
&&
derivNode
!=
NULL
)
{
for
(
int
j
=
1
;
j
<
(
int
)
nodes
.
size
();
j
++
)
{
string
name2
=
prefix
+
context
.
intToString
(
temps
.
size
());
string
name2
=
prefix
+
context
.
intToString
(
temps
.
size
());
out
<<
tempType
<<
" "
<<
name2
<<
" = 0.0f;
\n
"
;
out
<<
tempType
<<
" "
<<
name2
<<
" = 0.0f;
\n
"
;
if
(
isDeriv
)
{
nodeNames
.
push_back
(
name2
);
valueName
=
name2
;
temps
.
push_back
(
make_pair
(
*
nodes
[
j
],
name2
));
temps
.
push_back
(
make_pair
(
*
valueNode
,
name2
));
}
else
{
derivName
=
name2
;
temps
.
push_back
(
make_pair
(
*
derivNode
,
name2
));
}
}
}
out
<<
"{
\n
"
;
out
<<
"{
\n
"
;
if
(
dynamic_cast
<
const
Continuous1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
if
(
dynamic_cast
<
const
Continuous1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
...
@@ -119,20 +114,58 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
...
@@ -119,20 +114,58 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out
<<
"float4 coeff = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"float4 coeff = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"real b = x-index;
\n
"
;
out
<<
"real b = x-index;
\n
"
;
out
<<
"real a = 1.0f-b;
\n
"
;
out
<<
"real a = 1.0f-b;
\n
"
;
if
(
valueNode
!=
NULL
)
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
out
<<
valueName
<<
" = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);
\n
"
;
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
if
(
derivNode
!=
NULL
)
if
(
derivOrder
[
0
]
==
0
)
out
<<
derivName
<<
" = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);
\n
"
;
else
out
<<
nodeNames
[
j
]
<<
" = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;
\n
"
;
}
out
<<
"}
\n
"
;
out
<<
"}
\n
"
;
}
}
else
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
else
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
if
(
valueNode
!=
NULL
)
{
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
out
<<
"real x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
if
(
derivOrder
[
0
]
==
0
)
{
out
<<
"if (x >= 0 && x < params.x) {
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"int index = (int) round(x);
\n
"
;
out
<<
"real x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
out
<<
valueName
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"if (x >= 0 && x < params.x) {
\n
"
;
out
<<
"}
\n
"
;
out
<<
"int index = (int) round(x);
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"}
\n
"
;
}
}
}
else
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
if
(
derivOrder
[
0
]
==
0
&&
derivOrder
[
1
]
==
0
)
{
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"int x = (int) round("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
");
\n
"
;
out
<<
"int y = (int) round("
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
");
\n
"
;
out
<<
"int xsize = (int) params.x;
\n
"
;
out
<<
"int ysize = (int) params.y;
\n
"
;
out
<<
"int index = x+y*xsize;
\n
"
;
out
<<
"if (index >= 0 && index < xsize*ysize)
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
}
}
}
else
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
if
(
derivOrder
[
0
]
==
0
&&
derivOrder
[
1
]
==
0
&&
derivOrder
[
2
]
==
0
)
{
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"int x = (int) round("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
");
\n
"
;
out
<<
"int y = (int) round("
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
");
\n
"
;
out
<<
"int z = (int) round("
<<
getTempName
(
node
.
getChildren
()[
2
],
temps
)
<<
");
\n
"
;
out
<<
"int xsize = (int) params.x;
\n
"
;
out
<<
"int ysize = (int) params.y;
\n
"
;
out
<<
"int zsize = (int) params.z;
\n
"
;
out
<<
"int index = x+(y+z*ysize)*xsize;
\n
"
;
out
<<
"if (index >= 0 && index < xsize*ysize*zsize)
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
}
}
}
}
}
out
<<
"}"
;
out
<<
"}"
;
...
@@ -327,16 +360,12 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
...
@@ -327,16 +360,12 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
}
}
void
CudaExpressionUtilities
::
findRelatedTabulatedFunctions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
void
CudaExpressionUtilities
::
findRelatedTabulatedFunctions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
const
ExpressionTreeNode
*&
valueNode
,
const
ExpressionTreeNode
*&
derivNode
)
{
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
)
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getChildren
()[
0
]
==
searchNode
.
getChildren
()[
0
])
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getChildren
()[
0
]
==
searchNode
.
getChildren
()[
0
])
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
searchNode
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
nodes
.
push_back
(
&
searchNode
);
valueNode
=
&
searchNode
;
else
derivNode
=
&
searchNode
;
}
else
else
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
findRelatedTabulatedFunctions
(
node
,
searchNode
.
getChildren
()[
i
],
valueNode
,
derivNode
);
findRelatedTabulatedFunctions
(
node
,
searchNode
.
getChildren
()[
i
],
nodes
);
}
}
void
CudaExpressionUtilities
::
findRelatedPowers
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
map
<
int
,
const
ExpressionTreeNode
*>&
powers
)
{
void
CudaExpressionUtilities
::
findRelatedPowers
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
map
<
int
,
const
ExpressionTreeNode
*>&
powers
)
{
...
@@ -392,6 +421,34 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat
...
@@ -392,6 +421,34 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat
width
=
1
;
width
=
1
;
return
f
;
return
f
;
}
}
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
&
function
)
!=
NULL
)
{
// Record the tabulated values.
const
Discrete2DFunction
&
fn
=
dynamic_cast
<
const
Discrete2DFunction
&>
(
function
);
int
xsize
,
ysize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
values
);
int
numValues
=
values
.
size
();
vector
<
float
>
f
(
numValues
);
for
(
int
i
=
0
;
i
<
numValues
;
i
++
)
f
[
i
]
=
(
float
)
values
[
i
];
width
=
1
;
return
f
;
}
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
&
function
)
!=
NULL
)
{
// Record the tabulated values.
const
Discrete3DFunction
&
fn
=
dynamic_cast
<
const
Discrete3DFunction
&>
(
function
);
int
xsize
,
ysize
,
zsize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
zsize
,
values
);
int
numValues
=
values
.
size
();
vector
<
float
>
f
(
numValues
);
for
(
int
i
=
0
;
i
<
numValues
;
i
++
)
f
[
i
]
=
(
float
)
values
[
i
];
width
=
1
;
return
f
;
}
throw
OpenMMException
(
"computeFunctionCoefficients: Unknown function type"
);
throw
OpenMMException
(
"computeFunctionCoefficients: Unknown function type"
);
}
}
...
@@ -411,8 +468,34 @@ vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<c
...
@@ -411,8 +468,34 @@ vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<c
fn
.
getFunctionParameters
(
values
);
fn
.
getFunctionParameters
(
values
);
params
[
i
]
=
make_float4
((
float
)
values
.
size
(),
0.0
f
,
0.0
f
,
0.0
f
);
params
[
i
]
=
make_float4
((
float
)
values
.
size
(),
0.0
f
,
0.0
f
,
0.0
f
);
}
}
else
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
const
Discrete2DFunction
&
fn
=
dynamic_cast
<
const
Discrete2DFunction
&>
(
*
functions
[
i
]);
int
xsize
,
ysize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
values
);
params
[
i
]
=
make_float4
(
xsize
,
ysize
,
0.0
f
,
0.0
f
);
}
else
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
const
Discrete3DFunction
&
fn
=
dynamic_cast
<
const
Discrete3DFunction
&>
(
*
functions
[
i
]);
int
xsize
,
ysize
,
zsize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
zsize
,
values
);
params
[
i
]
=
make_float4
(
xsize
,
ysize
,
zsize
,
0.0
f
);
}
else
else
throw
OpenMMException
(
"computeFunctionParameters: Unknown function type"
);
throw
OpenMMException
(
"computeFunctionParameters: Unknown function type"
);
}
}
return
params
;
return
params
;
}
}
Lepton
::
CustomFunction
*
CudaExpressionUtilities
::
getFunctionPlaceholder
(
const
TabulatedFunction
&
function
)
{
if
(
dynamic_cast
<
const
Continuous1DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp1
;
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp1
;
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp2
;
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp3
;
throw
OpenMMException
(
"getFunctionPlaceholder: Unknown function type"
);
}
platforms/cuda/src/CudaKernels.cpp
View file @
a773952e
...
@@ -1958,7 +1958,6 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
...
@@ -1958,7 +1958,6 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
// Record the tabulated functions.
// Record the tabulated functions.
CudaExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -1967,7 +1966,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
...
@@ -1967,7 +1966,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
string
arrayName
=
prefix
+
"table"
+
cu
.
intToString
(
i
);
string
arrayName
=
prefix
+
"table"
+
cu
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cu
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
tabulatedFunctions
.
push_back
(
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
));
tabulatedFunctions
.
push_back
(
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
));
...
@@ -2671,7 +2670,6 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
...
@@ -2671,7 +2670,6 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
// Record the tabulated functions.
// Record the tabulated functions.
CudaExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -2681,7 +2679,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
...
@@ -2681,7 +2679,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
string
arrayName
=
prefix
+
"table"
+
cu
.
intToString
(
i
);
string
arrayName
=
prefix
+
"table"
+
cu
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cu
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
tabulatedFunctions
.
push_back
(
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
));
tabulatedFunctions
.
push_back
(
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
));
...
@@ -3786,7 +3784,6 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
...
@@ -3786,7 +3784,6 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
// Record the tabulated functions.
// Record the tabulated functions.
CudaExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -3796,7 +3793,7 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
...
@@ -3796,7 +3793,7 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
string
arrayName
=
"table"
+
cu
.
intToString
(
i
);
string
arrayName
=
"table"
+
cu
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cu
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
tabulatedFunctions
.
push_back
(
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
));
tabulatedFunctions
.
push_back
(
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
));
...
@@ -4182,7 +4179,6 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
...
@@ -4182,7 +4179,6 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
// Record the tabulated functions.
// Record the tabulated functions.
CudaExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -4190,7 +4186,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
...
@@ -4190,7 +4186,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getFunction
(
i
));
functionList
.
push_back
(
&
force
.
getFunction
(
i
));
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cu
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cu
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
CudaArray
*
array
=
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
);
CudaArray
*
array
=
CudaArray
::
create
<
float
>
(
cu
,
f
.
size
(),
"TabulatedFunction"
);
...
...
platforms/cuda/tests/TestCudaCustomNonbondedForce.cpp
View file @
a773952e
...
@@ -271,7 +271,7 @@ void testContinuous1DFunction() {
...
@@ -271,7 +271,7 @@ void testContinuous1DFunction() {
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
21
;
i
++
)
for
(
int
i
=
0
;
i
<
21
;
i
++
)
table
.
push_back
(
std
::
sin
(
0.25
*
i
));
table
.
push_back
(
sin
(
0.25
*
i
));
forceField
->
addFunction
(
"fn"
,
new
Continuous1DFunction
(
table
,
1.0
,
6.0
));
forceField
->
addFunction
(
"fn"
,
new
Continuous1DFunction
(
table
,
1.0
,
6.0
));
system
.
addForce
(
forceField
);
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
Context
context
(
system
,
integrator
,
platform
);
...
@@ -284,8 +284,8 @@ void testContinuous1DFunction() {
...
@@ -284,8 +284,8 @@ void testContinuous1DFunction() {
context
.
setPositions
(
positions
);
context
.
setPositions
(
positions
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
double
force
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
-
std
::
cos
(
x
-
1.0
));
double
force
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
-
cos
(
x
-
1.0
));
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
std
::
sin
(
x
-
1.0
))
+
1.0
;
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
sin
(
x
-
1.0
))
+
1.0
;
ASSERT_EQUAL_VEC
(
Vec3
(
-
force
,
0
,
0
),
forces
[
0
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
-
force
,
0
,
0
),
forces
[
0
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
force
,
0
,
0
),
forces
[
1
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
force
,
0
,
0
),
forces
[
1
],
0.1
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
0.02
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
0.02
);
...
@@ -295,7 +295,7 @@ void testContinuous1DFunction() {
...
@@ -295,7 +295,7 @@ void testContinuous1DFunction() {
positions
[
1
]
=
Vec3
(
x
,
0
,
0
);
positions
[
1
]
=
Vec3
(
x
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setPositions
(
positions
);
State
state
=
context
.
getState
(
State
::
Energy
);
State
state
=
context
.
getState
(
State
::
Energy
);
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
std
::
sin
(
x
-
1.0
))
+
1.0
;
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
sin
(
x
-
1.0
))
+
1.0
;
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
1e-4
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
1e-4
);
}
}
}
}
...
@@ -310,7 +310,7 @@ void testDiscrete1DFunction() {
...
@@ -310,7 +310,7 @@ void testDiscrete1DFunction() {
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
21
;
i
++
)
for
(
int
i
=
0
;
i
<
21
;
i
++
)
table
.
push_back
(
std
::
sin
(
0.25
*
i
));
table
.
push_back
(
sin
(
0.25
*
i
));
forceField
->
addFunction
(
"fn"
,
new
Discrete1DFunction
(
table
));
forceField
->
addFunction
(
"fn"
,
new
Discrete1DFunction
(
table
));
system
.
addForce
(
forceField
);
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
Context
context
(
system
,
integrator
,
platform
);
...
@@ -327,6 +327,74 @@ void testDiscrete1DFunction() {
...
@@ -327,6 +327,74 @@ void testDiscrete1DFunction() {
}
}
}
}
void
testDiscrete2DFunction
()
{
const
int
xsize
=
10
;
const
int
ysize
=
5
;
System
system
;
system
.
addParticle
(
1.0
);
system
.
addParticle
(
1.0
);
VerletIntegrator
integrator
(
0.01
);
CustomNonbondedForce
*
forceField
=
new
CustomNonbondedForce
(
"fn(r-1,a)+1"
);
forceField
->
addGlobalParameter
(
"a"
,
0.0
);
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
xsize
;
i
++
)
for
(
int
j
=
0
;
j
<
ysize
;
j
++
)
table
.
push_back
(
sin
(
0.25
*
i
)
+
cos
(
0.33
*
j
));
forceField
->
addFunction
(
"fn"
,
new
Discrete2DFunction
(
xsize
,
ysize
,
table
));
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
2
);
positions
[
0
]
=
Vec3
(
0
,
0
,
0
);
for
(
int
i
=
0
;
i
<
(
int
)
table
.
size
();
i
++
)
{
positions
[
1
]
=
Vec3
((
i
%
xsize
)
+
1
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setParameter
(
"a"
,
i
/
xsize
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
0
],
1e-6
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
1
],
1e-6
);
ASSERT_EQUAL_TOL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
(),
1e-6
);
}
}
void
testDiscrete3DFunction
()
{
const
int
xsize
=
8
;
const
int
ysize
=
5
;
const
int
zsize
=
6
;
System
system
;
system
.
addParticle
(
1.0
);
system
.
addParticle
(
1.0
);
VerletIntegrator
integrator
(
0.01
);
CustomNonbondedForce
*
forceField
=
new
CustomNonbondedForce
(
"fn(r-1,a,b)+1"
);
forceField
->
addGlobalParameter
(
"a"
,
0.0
);
forceField
->
addGlobalParameter
(
"b"
,
0.0
);
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
xsize
;
i
++
)
for
(
int
j
=
0
;
j
<
ysize
;
j
++
)
for
(
int
k
=
0
;
k
<
zsize
;
k
++
)
table
.
push_back
(
sin
(
0.25
*
i
)
+
cos
(
0.33
*
j
)
+
0.12345
*
k
);
forceField
->
addFunction
(
"fn"
,
new
Discrete3DFunction
(
xsize
,
ysize
,
zsize
,
table
));
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
2
);
positions
[
0
]
=
Vec3
(
0
,
0
,
0
);
for
(
int
i
=
0
;
i
<
(
int
)
table
.
size
();
i
++
)
{
positions
[
1
]
=
Vec3
((
i
%
xsize
)
+
1
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setParameter
(
"a"
,
(
i
/
xsize
)
%
ysize
);
context
.
setParameter
(
"b"
,
i
/
(
xsize
*
ysize
));
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
0
],
1e-6
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
1
],
1e-6
);
ASSERT_EQUAL_TOL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
(),
1e-6
);
}
}
void
testCoulombLennardJones
()
{
void
testCoulombLennardJones
()
{
const
int
numMolecules
=
300
;
const
int
numMolecules
=
300
;
const
int
numParticles
=
numMolecules
*
2
;
const
int
numParticles
=
numMolecules
*
2
;
...
@@ -754,6 +822,8 @@ int main(int argc, char* argv[]) {
...
@@ -754,6 +822,8 @@ int main(int argc, char* argv[]) {
testPeriodic
();
testPeriodic
();
testContinuous1DFunction
();
testContinuous1DFunction
();
testDiscrete1DFunction
();
testDiscrete1DFunction
();
testDiscrete2DFunction
();
testDiscrete3DFunction
();
testCoulombLennardJones
();
testCoulombLennardJones
();
testParallelComputation
();
testParallelComputation
();
testSwitchingFunction
();
testSwitchingFunction
();
...
...
platforms/opencl/include/OpenCLExpressionUtilities.h
View file @
a773952e
...
@@ -46,8 +46,7 @@ namespace OpenMM {
...
@@ -46,8 +46,7 @@ namespace OpenMM {
class
OPENMM_EXPORT_OPENCL
OpenCLExpressionUtilities
{
class
OPENMM_EXPORT_OPENCL
OpenCLExpressionUtilities
{
public:
public:
OpenCLExpressionUtilities
(
OpenCLContext
&
context
)
:
context
(
context
)
{
OpenCLExpressionUtilities
(
OpenCLContext
&
context
);
}
/**
/**
* Generate the source code for calculating a set of expressions.
* Generate the source code for calculating a set of expressions.
*
*
...
@@ -93,38 +92,43 @@ public:
...
@@ -93,38 +92,43 @@ public:
* @return the parameter array
* @return the parameter array
*/
*/
std
::
vector
<
mm_float4
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
std
::
vector
<
mm_float4
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
class
FunctionPlaceholder
;
/**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
*
* @param function the function for which to get a placeholder
*/
Lepton
::
CustomFunction
*
getFunctionPlaceholder
(
const
TabulatedFunction
&
function
);
private:
private:
class
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
public:
FunctionPlaceholder
(
int
numArgs
)
:
numArgs
(
numArgs
)
{
}
int
getNumArguments
()
const
{
return
numArgs
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
0.0
;
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
FunctionPlaceholder
(
numArgs
);
}
private:
int
numArgs
;
};
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functionNames
,
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functionNames
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
,
const
std
::
string
&
tempType
);
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
,
const
std
::
string
&
tempType
);
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
void
findRelatedTabulatedFunctions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelatedTabulatedFunctions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
const
Lepton
::
ExpressionTreeNode
*&
valueNode
,
const
Lepton
::
ExpressionTreeNode
*&
derivN
ode
);
std
::
vector
<
const
Lepton
::
ExpressionTreeNode
*
>
&
n
ode
s
);
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
OpenCLContext
&
context
;
OpenCLContext
&
context
;
};
FunctionPlaceholder
fp1
,
fp2
,
fp3
;
/**
* This class serves as a placeholder for custom functions in expressions.
*/
class
OpenCLExpressionUtilities
::
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
public:
int
getNumArguments
()
const
{
return
1
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
0.0
;
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
FunctionPlaceholder
();
}
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
platforms/opencl/src/OpenCLExpressionUtilities.cpp
View file @
a773952e
...
@@ -33,6 +33,9 @@ using namespace OpenMM;
...
@@ -33,6 +33,9 @@ using namespace OpenMM;
using
namespace
Lepton
;
using
namespace
Lepton
;
using
namespace
std
;
using
namespace
std
;
OpenCLExpressionUtilities
::
OpenCLExpressionUtilities
(
OpenCLContext
&
context
)
:
context
(
context
),
fp1
(
1
),
fp2
(
2
),
fp3
(
3
)
{
}
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
const
TabulatedFunction
*>&
functions
,
const
vector
<
pair
<
string
,
string
>
>&
functionNames
,
const
string
&
prefix
,
const
vector
<
const
TabulatedFunction
*>&
functions
,
const
vector
<
pair
<
string
,
string
>
>&
functionNames
,
const
string
&
prefix
,
const
string
&
functionParams
,
const
string
&
tempType
)
{
const
string
&
functionParams
,
const
string
&
tempType
)
{
...
@@ -82,7 +85,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -82,7 +85,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
;
;
if
(
i
==
functionNames
.
size
())
if
(
i
==
functionNames
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
bool
isDeriv
=
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
1
);
out
<<
"0.0f;
\n
"
;
out
<<
"0.0f;
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
temps
.
push_back
(
make_pair
(
node
,
name
));
hasRecordedNode
=
true
;
hasRecordedNode
=
true
;
...
@@ -90,23 +92,16 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -90,23 +92,16 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
// If both the value and derivative of the function are needed, it's faster to calculate them both
// If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed.
// at once, so check to see if both are needed.
const
ExpressionTreeNode
*
valueNode
=
NULL
;
vector
<
const
ExpressionTreeNode
*>
nodes
;
const
ExpressionTreeNode
*
derivNode
=
NULL
;
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
findRelatedTabulatedFunctions
(
node
,
allExpressions
[
j
].
getRootNode
(),
valueNode
,
derivNode
);
findRelatedTabulatedFunctions
(
node
,
allExpressions
[
j
].
getRootNode
(),
nodes
);
string
valueName
=
name
;
vector
<
string
>
nodeNames
;
string
derivName
=
name
;
nodeNames
.
push_back
(
name
)
;
if
(
valueNode
!=
NULL
&&
derivNode
!=
NULL
)
{
for
(
int
j
=
1
;
j
<
(
int
)
nodes
.
size
();
j
++
)
{
string
name2
=
prefix
+
context
.
intToString
(
temps
.
size
());
string
name2
=
prefix
+
context
.
intToString
(
temps
.
size
());
out
<<
tempType
<<
" "
<<
name2
<<
" = 0.0f;
\n
"
;
out
<<
tempType
<<
" "
<<
name2
<<
" = 0.0f;
\n
"
;
if
(
isDeriv
)
{
nodeNames
.
push_back
(
name2
);
valueName
=
name2
;
temps
.
push_back
(
make_pair
(
*
nodes
[
j
],
name2
));
temps
.
push_back
(
make_pair
(
*
valueNode
,
name2
));
}
else
{
derivName
=
name2
;
temps
.
push_back
(
make_pair
(
*
derivNode
,
name2
));
}
}
}
out
<<
"{
\n
"
;
out
<<
"{
\n
"
;
if
(
dynamic_cast
<
const
Continuous1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
if
(
dynamic_cast
<
const
Continuous1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
...
@@ -119,20 +114,58 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -119,20 +114,58 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out
<<
"float4 coeff = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"float4 coeff = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"real b = x-index;
\n
"
;
out
<<
"real b = x-index;
\n
"
;
out
<<
"real a = 1.0f-b;
\n
"
;
out
<<
"real a = 1.0f-b;
\n
"
;
if
(
valueNode
!=
NULL
)
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
out
<<
valueName
<<
" = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);
\n
"
;
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
if
(
derivNode
!=
NULL
)
if
(
derivOrder
[
0
]
==
0
)
out
<<
derivName
<<
" = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);
\n
"
;
else
out
<<
nodeNames
[
j
]
<<
" = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;
\n
"
;
}
out
<<
"}
\n
"
;
out
<<
"}
\n
"
;
}
}
else
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
else
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
if
(
valueNode
!=
NULL
)
{
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
out
<<
"real x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
if
(
derivOrder
[
0
]
==
0
)
{
out
<<
"if (x >= 0 && x < params.x) {
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"int index = (int) round(x);
\n
"
;
out
<<
"real x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
out
<<
valueName
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"if (x >= 0 && x < params.x) {
\n
"
;
out
<<
"}
\n
"
;
out
<<
"int index = (int) round(x);
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"}
\n
"
;
}
}
}
else
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
if
(
derivOrder
[
0
]
==
0
&&
derivOrder
[
1
]
==
0
)
{
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"int x = (int) round("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
");
\n
"
;
out
<<
"int y = (int) round("
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
");
\n
"
;
out
<<
"int xsize = (int) params.x;
\n
"
;
out
<<
"int ysize = (int) params.y;
\n
"
;
out
<<
"int index = x+y*xsize;
\n
"
;
out
<<
"if (index >= 0 && index < xsize*ysize)
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
}
}
}
else
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
for
(
int
j
=
0
;
j
<
nodes
.
size
();
j
++
)
{
const
vector
<
int
>&
derivOrder
=
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
nodes
[
j
]
->
getOperation
())
->
getDerivOrder
();
if
(
derivOrder
[
0
]
==
0
&&
derivOrder
[
1
]
==
0
&&
derivOrder
[
2
]
==
0
)
{
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"int x = (int) round("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
");
\n
"
;
out
<<
"int y = (int) round("
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
");
\n
"
;
out
<<
"int z = (int) round("
<<
getTempName
(
node
.
getChildren
()[
2
],
temps
)
<<
");
\n
"
;
out
<<
"int xsize = (int) params.x;
\n
"
;
out
<<
"int ysize = (int) params.y;
\n
"
;
out
<<
"int zsize = (int) params.z;
\n
"
;
out
<<
"int index = x+(y+z*ysize)*xsize;
\n
"
;
out
<<
"if (index >= 0 && index < xsize*ysize*zsize)
\n
"
;
out
<<
nodeNames
[
j
]
<<
" = "
<<
functionNames
[
i
].
second
<<
"[index];
\n
"
;
}
}
}
}
}
out
<<
"}"
;
out
<<
"}"
;
...
@@ -327,16 +360,12 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
...
@@ -327,16 +360,12 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
}
}
void
OpenCLExpressionUtilities
::
findRelatedTabulatedFunctions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
void
OpenCLExpressionUtilities
::
findRelatedTabulatedFunctions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
const
ExpressionTreeNode
*&
valueNode
,
const
ExpressionTreeNode
*&
derivNode
)
{
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
)
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getChildren
()[
0
]
==
searchNode
.
getChildren
()[
0
])
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getChildren
()[
0
]
==
searchNode
.
getChildren
()[
0
])
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
searchNode
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
nodes
.
push_back
(
&
searchNode
);
valueNode
=
&
searchNode
;
else
derivNode
=
&
searchNode
;
}
else
else
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
findRelatedTabulatedFunctions
(
node
,
searchNode
.
getChildren
()[
i
],
valueNode
,
derivNode
);
findRelatedTabulatedFunctions
(
node
,
searchNode
.
getChildren
()[
i
],
nodes
);
}
}
void
OpenCLExpressionUtilities
::
findRelatedPowers
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
map
<
int
,
const
ExpressionTreeNode
*>&
powers
)
{
void
OpenCLExpressionUtilities
::
findRelatedPowers
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
map
<
int
,
const
ExpressionTreeNode
*>&
powers
)
{
...
@@ -392,6 +421,34 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul
...
@@ -392,6 +421,34 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul
width
=
1
;
width
=
1
;
return
f
;
return
f
;
}
}
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
&
function
)
!=
NULL
)
{
// Record the tabulated values.
const
Discrete2DFunction
&
fn
=
dynamic_cast
<
const
Discrete2DFunction
&>
(
function
);
int
xsize
,
ysize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
values
);
int
numValues
=
values
.
size
();
vector
<
float
>
f
(
numValues
);
for
(
int
i
=
0
;
i
<
numValues
;
i
++
)
f
[
i
]
=
(
float
)
values
[
i
];
width
=
1
;
return
f
;
}
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
&
function
)
!=
NULL
)
{
// Record the tabulated values.
const
Discrete3DFunction
&
fn
=
dynamic_cast
<
const
Discrete3DFunction
&>
(
function
);
int
xsize
,
ysize
,
zsize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
zsize
,
values
);
int
numValues
=
values
.
size
();
vector
<
float
>
f
(
numValues
);
for
(
int
i
=
0
;
i
<
numValues
;
i
++
)
f
[
i
]
=
(
float
)
values
[
i
];
width
=
1
;
return
f
;
}
throw
OpenMMException
(
"computeFunctionCoefficients: Unknown function type"
);
throw
OpenMMException
(
"computeFunctionCoefficients: Unknown function type"
);
}
}
...
@@ -411,8 +468,34 @@ vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vec
...
@@ -411,8 +468,34 @@ vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vec
fn
.
getFunctionParameters
(
values
);
fn
.
getFunctionParameters
(
values
);
params
[
i
]
=
mm_float4
((
float
)
values
.
size
(),
0.0
f
,
0.0
f
,
0.0
f
);
params
[
i
]
=
mm_float4
((
float
)
values
.
size
(),
0.0
f
,
0.0
f
,
0.0
f
);
}
}
else
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
const
Discrete2DFunction
&
fn
=
dynamic_cast
<
const
Discrete2DFunction
&>
(
*
functions
[
i
]);
int
xsize
,
ysize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
values
);
params
[
i
]
=
mm_float4
(
xsize
,
ysize
,
0.0
f
,
0.0
f
);
}
else
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
functions
[
i
])
!=
NULL
)
{
const
Discrete3DFunction
&
fn
=
dynamic_cast
<
const
Discrete3DFunction
&>
(
*
functions
[
i
]);
int
xsize
,
ysize
,
zsize
;
vector
<
double
>
values
;
fn
.
getFunctionParameters
(
xsize
,
ysize
,
zsize
,
values
);
params
[
i
]
=
mm_float4
(
xsize
,
ysize
,
zsize
,
0.0
f
);
}
else
else
throw
OpenMMException
(
"computeFunctionParameters: Unknown function type"
);
throw
OpenMMException
(
"computeFunctionParameters: Unknown function type"
);
}
}
return
params
;
return
params
;
}
}
Lepton
::
CustomFunction
*
OpenCLExpressionUtilities
::
getFunctionPlaceholder
(
const
TabulatedFunction
&
function
)
{
if
(
dynamic_cast
<
const
Continuous1DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp1
;
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp1
;
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp2
;
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
&
function
)
!=
NULL
)
return
&
fp3
;
throw
OpenMMException
(
"getFunctionPlaceholder: Unknown function type"
);
}
platforms/opencl/src/OpenCLKernels.cpp
View file @
a773952e
...
@@ -1968,7 +1968,6 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -1968,7 +1968,6 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
// Record the tabulated functions.
// Record the tabulated functions.
OpenCLExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -1977,7 +1976,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -1977,7 +1976,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
string
arrayName
=
prefix
+
"table"
+
cl
.
intToString
(
i
);
string
arrayName
=
prefix
+
"table"
+
cl
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cl
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
tabulatedFunctions
.
push_back
(
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
));
tabulatedFunctions
.
push_back
(
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
));
...
@@ -2724,7 +2723,6 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
...
@@ -2724,7 +2723,6 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
// Record the tabulated functions.
// Record the tabulated functions.
OpenCLExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -2734,7 +2732,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
...
@@ -2734,7 +2732,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
string
arrayName
=
prefix
+
"table"
+
cl
.
intToString
(
i
);
string
arrayName
=
prefix
+
"table"
+
cl
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cl
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
tabulatedFunctions
.
push_back
(
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
));
tabulatedFunctions
.
push_back
(
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
));
...
@@ -3949,7 +3947,6 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
...
@@ -3949,7 +3947,6 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Record the tabulated functions.
// Record the tabulated functions.
OpenCLExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -3959,7 +3956,7 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
...
@@ -3959,7 +3956,7 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
string
arrayName
=
"table"
+
cl
.
intToString
(
i
);
string
arrayName
=
"table"
+
cl
.
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cl
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
tabulatedFunctions
.
push_back
(
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
));
tabulatedFunctions
.
push_back
(
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
));
...
@@ -4347,7 +4344,6 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
...
@@ -4347,7 +4344,6 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
// Record the tabulated functions.
// Record the tabulated functions.
OpenCLExpressionUtilities
::
FunctionPlaceholder
fp
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
const
TabulatedFunction
*>
functionList
;
vector
<
const
TabulatedFunction
*>
functionList
;
...
@@ -4355,7 +4351,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
...
@@ -4355,7 +4351,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
{
functionList
.
push_back
(
&
force
.
getFunction
(
i
));
functionList
.
push_back
(
&
force
.
getFunction
(
i
));
string
name
=
force
.
getFunctionName
(
i
);
string
name
=
force
.
getFunctionName
(
i
);
functions
[
name
]
=
&
fp
;
functions
[
name
]
=
cl
.
getExpressionUtilities
().
getFunctionPlaceholder
(
force
.
getFunction
(
i
))
;
int
width
;
int
width
;
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
vector
<
float
>
f
=
cl
.
getExpressionUtilities
().
computeFunctionCoefficients
(
force
.
getFunction
(
i
),
width
);
OpenCLArray
*
array
=
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
);
OpenCLArray
*
array
=
OpenCLArray
::
create
<
float
>
(
cl
,
f
.
size
(),
"TabulatedFunction"
);
...
...
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
View file @
a773952e
...
@@ -271,7 +271,7 @@ void testContinuous1DFunction() {
...
@@ -271,7 +271,7 @@ void testContinuous1DFunction() {
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
21
;
i
++
)
for
(
int
i
=
0
;
i
<
21
;
i
++
)
table
.
push_back
(
std
::
sin
(
0.25
*
i
));
table
.
push_back
(
sin
(
0.25
*
i
));
forceField
->
addFunction
(
"fn"
,
new
Continuous1DFunction
(
table
,
1.0
,
6.0
));
forceField
->
addFunction
(
"fn"
,
new
Continuous1DFunction
(
table
,
1.0
,
6.0
));
system
.
addForce
(
forceField
);
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
Context
context
(
system
,
integrator
,
platform
);
...
@@ -284,8 +284,8 @@ void testContinuous1DFunction() {
...
@@ -284,8 +284,8 @@ void testContinuous1DFunction() {
context
.
setPositions
(
positions
);
context
.
setPositions
(
positions
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
double
force
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
-
std
::
cos
(
x
-
1.0
));
double
force
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
-
cos
(
x
-
1.0
));
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
std
::
sin
(
x
-
1.0
))
+
1.0
;
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
sin
(
x
-
1.0
))
+
1.0
;
ASSERT_EQUAL_VEC
(
Vec3
(
-
force
,
0
,
0
),
forces
[
0
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
-
force
,
0
,
0
),
forces
[
0
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
force
,
0
,
0
),
forces
[
1
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
force
,
0
,
0
),
forces
[
1
],
0.1
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
0.02
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
0.02
);
...
@@ -295,7 +295,7 @@ void testContinuous1DFunction() {
...
@@ -295,7 +295,7 @@ void testContinuous1DFunction() {
positions
[
1
]
=
Vec3
(
x
,
0
,
0
);
positions
[
1
]
=
Vec3
(
x
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setPositions
(
positions
);
State
state
=
context
.
getState
(
State
::
Energy
);
State
state
=
context
.
getState
(
State
::
Energy
);
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
std
::
sin
(
x
-
1.0
))
+
1.0
;
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
sin
(
x
-
1.0
))
+
1.0
;
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
1e-4
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
1e-4
);
}
}
}
}
...
@@ -310,7 +310,7 @@ void testDiscrete1DFunction() {
...
@@ -310,7 +310,7 @@ void testDiscrete1DFunction() {
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
21
;
i
++
)
for
(
int
i
=
0
;
i
<
21
;
i
++
)
table
.
push_back
(
std
::
sin
(
0.25
*
i
));
table
.
push_back
(
sin
(
0.25
*
i
));
forceField
->
addFunction
(
"fn"
,
new
Discrete1DFunction
(
table
));
forceField
->
addFunction
(
"fn"
,
new
Discrete1DFunction
(
table
));
system
.
addForce
(
forceField
);
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
Context
context
(
system
,
integrator
,
platform
);
...
@@ -326,6 +326,73 @@ void testDiscrete1DFunction() {
...
@@ -326,6 +326,73 @@ void testDiscrete1DFunction() {
ASSERT_EQUAL_TOL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
(),
1e-6
);
ASSERT_EQUAL_TOL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
(),
1e-6
);
}
}
}
}
void
testDiscrete2DFunction
()
{
const
int
xsize
=
10
;
const
int
ysize
=
5
;
System
system
;
system
.
addParticle
(
1.0
);
system
.
addParticle
(
1.0
);
VerletIntegrator
integrator
(
0.01
);
CustomNonbondedForce
*
forceField
=
new
CustomNonbondedForce
(
"fn(r-1,a)+1"
);
forceField
->
addGlobalParameter
(
"a"
,
0.0
);
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
xsize
;
i
++
)
for
(
int
j
=
0
;
j
<
ysize
;
j
++
)
table
.
push_back
(
sin
(
0.25
*
i
)
+
cos
(
0.33
*
j
));
forceField
->
addFunction
(
"fn"
,
new
Discrete2DFunction
(
xsize
,
ysize
,
table
));
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
2
);
positions
[
0
]
=
Vec3
(
0
,
0
,
0
);
for
(
int
i
=
0
;
i
<
(
int
)
table
.
size
();
i
++
)
{
positions
[
1
]
=
Vec3
((
i
%
xsize
)
+
1
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setParameter
(
"a"
,
i
/
xsize
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
0
],
1e-6
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
1
],
1e-6
);
ASSERT_EQUAL_TOL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
(),
1e-6
);
}
}
void
testDiscrete3DFunction
()
{
const
int
xsize
=
8
;
const
int
ysize
=
5
;
const
int
zsize
=
6
;
System
system
;
system
.
addParticle
(
1.0
);
system
.
addParticle
(
1.0
);
VerletIntegrator
integrator
(
0.01
);
CustomNonbondedForce
*
forceField
=
new
CustomNonbondedForce
(
"fn(r-1,a,b)+1"
);
forceField
->
addGlobalParameter
(
"a"
,
0.0
);
forceField
->
addGlobalParameter
(
"b"
,
0.0
);
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
xsize
;
i
++
)
for
(
int
j
=
0
;
j
<
ysize
;
j
++
)
for
(
int
k
=
0
;
k
<
zsize
;
k
++
)
table
.
push_back
(
sin
(
0.25
*
i
)
+
cos
(
0.33
*
j
)
+
0.12345
*
k
);
forceField
->
addFunction
(
"fn"
,
new
Discrete3DFunction
(
xsize
,
ysize
,
zsize
,
table
));
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
2
);
positions
[
0
]
=
Vec3
(
0
,
0
,
0
);
for
(
int
i
=
0
;
i
<
(
int
)
table
.
size
();
i
++
)
{
positions
[
1
]
=
Vec3
((
i
%
xsize
)
+
1
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setParameter
(
"a"
,
(
i
/
xsize
)
%
ysize
);
context
.
setParameter
(
"b"
,
i
/
(
xsize
*
ysize
));
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
0
],
1e-6
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
1
],
1e-6
);
ASSERT_EQUAL_TOL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
(),
1e-6
);
}
}
void
testCoulombLennardJones
()
{
void
testCoulombLennardJones
()
{
const
int
numMolecules
=
300
;
const
int
numMolecules
=
300
;
...
@@ -754,6 +821,8 @@ int main(int argc, char* argv[]) {
...
@@ -754,6 +821,8 @@ int main(int argc, char* argv[]) {
testPeriodic
();
testPeriodic
();
testContinuous1DFunction
();
testContinuous1DFunction
();
testDiscrete1DFunction
();
testDiscrete1DFunction
();
testDiscrete2DFunction
();
testDiscrete3DFunction
();
testCoulombLennardJones
();
testCoulombLennardJones
();
testParallelComputation
();
testParallelComputation
();
testSwitchingFunction
();
testSwitchingFunction
();
...
...
platforms/reference/include/ReferenceTabulatedFunction.h
View file @
a773952e
...
@@ -75,6 +75,38 @@ private:
...
@@ -75,6 +75,38 @@ private:
std
::
vector
<
double
>
values
;
std
::
vector
<
double
>
values
;
};
};
/**
* This class adapts a Discrete2DFunction into a Lepton::CustomFunction.
*/
class
OPENMM_EXPORT
ReferenceDiscrete2DFunction
:
public
Lepton
::
CustomFunction
{
public:
ReferenceDiscrete2DFunction
(
const
Discrete2DFunction
&
function
);
int
getNumArguments
()
const
;
double
evaluate
(
const
double
*
arguments
)
const
;
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
;
CustomFunction
*
clone
()
const
;
private:
const
Discrete2DFunction
&
function
;
int
xsize
,
ysize
;
std
::
vector
<
double
>
values
;
};
/**
* This class adapts a Discrete3DFunction into a Lepton::CustomFunction.
*/
class
OPENMM_EXPORT
ReferenceDiscrete3DFunction
:
public
Lepton
::
CustomFunction
{
public:
ReferenceDiscrete3DFunction
(
const
Discrete3DFunction
&
function
);
int
getNumArguments
()
const
;
double
evaluate
(
const
double
*
arguments
)
const
;
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
;
CustomFunction
*
clone
()
const
;
private:
const
Discrete3DFunction
&
function
;
int
xsize
,
ysize
,
zsize
;
std
::
vector
<
double
>
values
;
};
}
// namespace OpenMM
}
// namespace OpenMM
#endif
/*OPENMM_REFERENCETABULATEDFUNCTION_H_*/
#endif
/*OPENMM_REFERENCETABULATEDFUNCTION_H_*/
platforms/reference/src/ReferenceTabulatedFunction.cpp
View file @
a773952e
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include "ReferenceTabulatedFunction.h"
#include "ReferenceTabulatedFunction.h"
#include "openmm/OpenMMException.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "openmm/internal/SplineFitter.h"
#include <cmath>
using
namespace
OpenMM
;
using
namespace
OpenMM
;
using
namespace
std
;
using
namespace
std
;
...
@@ -42,6 +43,10 @@ extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunct
...
@@ -42,6 +43,10 @@ extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunct
return
new
ReferenceContinuous1DFunction
(
dynamic_cast
<
const
Continuous1DFunction
&>
(
function
));
return
new
ReferenceContinuous1DFunction
(
dynamic_cast
<
const
Continuous1DFunction
&>
(
function
));
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
&
function
)
!=
NULL
)
if
(
dynamic_cast
<
const
Discrete1DFunction
*>
(
&
function
)
!=
NULL
)
return
new
ReferenceDiscrete1DFunction
(
dynamic_cast
<
const
Discrete1DFunction
&>
(
function
));
return
new
ReferenceDiscrete1DFunction
(
dynamic_cast
<
const
Discrete1DFunction
&>
(
function
));
if
(
dynamic_cast
<
const
Discrete2DFunction
*>
(
&
function
)
!=
NULL
)
return
new
ReferenceDiscrete2DFunction
(
dynamic_cast
<
const
Discrete2DFunction
&>
(
function
));
if
(
dynamic_cast
<
const
Discrete3DFunction
*>
(
&
function
)
!=
NULL
)
return
new
ReferenceDiscrete3DFunction
(
dynamic_cast
<
const
Discrete3DFunction
&>
(
function
));
throw
OpenMMException
(
"createReferenceTabulatedFunction: Unknown function type"
);
throw
OpenMMException
(
"createReferenceTabulatedFunction: Unknown function type"
);
}
}
...
@@ -85,10 +90,10 @@ int ReferenceDiscrete1DFunction::getNumArguments() const {
...
@@ -85,10 +90,10 @@ int ReferenceDiscrete1DFunction::getNumArguments() const {
}
}
double
ReferenceDiscrete1DFunction
::
evaluate
(
const
double
*
arguments
)
const
{
double
ReferenceDiscrete1DFunction
::
evaluate
(
const
double
*
arguments
)
const
{
int
t
=
(
int
)
arguments
[
0
];
int
i
=
(
int
)
round
(
arguments
[
0
]
)
;
if
(
t
<
0
||
t
>=
values
.
size
())
if
(
i
<
0
||
i
>=
values
.
size
())
throw
OpenMMException
(
"ReferenceDiscrete1DFunction: argument out of range"
);
throw
OpenMMException
(
"ReferenceDiscrete1DFunction: argument out of range"
);
return
values
[
t
];
return
values
[
i
];
}
}
double
ReferenceDiscrete1DFunction
::
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
double
ReferenceDiscrete1DFunction
::
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
...
@@ -98,3 +103,52 @@ double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments,
...
@@ -98,3 +103,52 @@ double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments,
CustomFunction
*
ReferenceDiscrete1DFunction
::
clone
()
const
{
CustomFunction
*
ReferenceDiscrete1DFunction
::
clone
()
const
{
return
new
ReferenceDiscrete1DFunction
(
function
);
return
new
ReferenceDiscrete1DFunction
(
function
);
}
}
ReferenceDiscrete2DFunction
::
ReferenceDiscrete2DFunction
(
const
Discrete2DFunction
&
function
)
:
function
(
function
)
{
function
.
getFunctionParameters
(
xsize
,
ysize
,
values
);
}
int
ReferenceDiscrete2DFunction
::
getNumArguments
()
const
{
return
2
;
}
double
ReferenceDiscrete2DFunction
::
evaluate
(
const
double
*
arguments
)
const
{
int
i
=
(
int
)
round
(
arguments
[
0
]);
int
j
=
(
int
)
round
(
arguments
[
1
]);
if
(
i
<
0
||
i
>=
xsize
||
j
<
0
||
j
>=
ysize
)
throw
OpenMMException
(
"ReferenceDiscrete2DFunction: argument out of range"
);
return
values
[
i
+
j
*
xsize
];
}
double
ReferenceDiscrete2DFunction
::
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
ReferenceDiscrete2DFunction
::
clone
()
const
{
return
new
ReferenceDiscrete2DFunction
(
function
);
}
ReferenceDiscrete3DFunction
::
ReferenceDiscrete3DFunction
(
const
Discrete3DFunction
&
function
)
:
function
(
function
)
{
function
.
getFunctionParameters
(
xsize
,
ysize
,
zsize
,
values
);
}
int
ReferenceDiscrete3DFunction
::
getNumArguments
()
const
{
return
3
;
}
double
ReferenceDiscrete3DFunction
::
evaluate
(
const
double
*
arguments
)
const
{
int
i
=
(
int
)
round
(
arguments
[
0
]);
int
j
=
(
int
)
round
(
arguments
[
1
]);
int
k
=
(
int
)
round
(
arguments
[
2
]);
if
(
i
<
0
||
i
>=
xsize
||
j
<
0
||
j
>=
ysize
||
k
<
0
||
k
>=
zsize
)
throw
OpenMMException
(
"ReferenceDiscrete3DFunction: argument out of range"
);
return
values
[
i
+
(
j
+
k
*
ysize
)
*
xsize
];
}
double
ReferenceDiscrete3DFunction
::
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
ReferenceDiscrete3DFunction
::
clone
()
const
{
return
new
ReferenceDiscrete3DFunction
(
function
);
}
platforms/reference/tests/TestReferenceCustomNonbondedForce.cpp
View file @
a773952e
...
@@ -238,7 +238,7 @@ void testContinuous1DFunction() {
...
@@ -238,7 +238,7 @@ void testContinuous1DFunction() {
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
21
;
i
++
)
for
(
int
i
=
0
;
i
<
21
;
i
++
)
table
.
push_back
(
std
::
sin
(
0.25
*
i
));
table
.
push_back
(
sin
(
0.25
*
i
));
forceField
->
addFunction
(
"fn"
,
new
Continuous1DFunction
(
table
,
1.0
,
6.0
));
forceField
->
addFunction
(
"fn"
,
new
Continuous1DFunction
(
table
,
1.0
,
6.0
));
system
.
addForce
(
forceField
);
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
Context
context
(
system
,
integrator
,
platform
);
...
@@ -251,8 +251,8 @@ void testContinuous1DFunction() {
...
@@ -251,8 +251,8 @@ void testContinuous1DFunction() {
context
.
setPositions
(
positions
);
context
.
setPositions
(
positions
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
double
force
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
-
std
::
cos
(
x
-
1.0
));
double
force
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
-
cos
(
x
-
1.0
));
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
std
::
sin
(
x
-
1.0
))
+
1.0
;
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
sin
(
x
-
1.0
))
+
1.0
;
ASSERT_EQUAL_VEC
(
Vec3
(
-
force
,
0
,
0
),
forces
[
0
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
-
force
,
0
,
0
),
forces
[
0
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
force
,
0
,
0
),
forces
[
1
],
0.1
);
ASSERT_EQUAL_VEC
(
Vec3
(
force
,
0
,
0
),
forces
[
1
],
0.1
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
0.02
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
0.02
);
...
@@ -262,7 +262,7 @@ void testContinuous1DFunction() {
...
@@ -262,7 +262,7 @@ void testContinuous1DFunction() {
positions
[
1
]
=
Vec3
(
x
,
0
,
0
);
positions
[
1
]
=
Vec3
(
x
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setPositions
(
positions
);
State
state
=
context
.
getState
(
State
::
Energy
);
State
state
=
context
.
getState
(
State
::
Energy
);
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
std
::
sin
(
x
-
1.0
))
+
1.0
;
double
energy
=
(
x
<
1.0
||
x
>
6.0
?
0.0
:
sin
(
x
-
1.0
))
+
1.0
;
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
1e-4
);
ASSERT_EQUAL_TOL
(
energy
,
state
.
getPotentialEnergy
(),
1e-4
);
}
}
}
}
...
@@ -278,7 +278,7 @@ void testDiscrete1DFunction() {
...
@@ -278,7 +278,7 @@ void testDiscrete1DFunction() {
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
21
;
i
++
)
for
(
int
i
=
0
;
i
<
21
;
i
++
)
table
.
push_back
(
std
::
sin
(
0.25
*
i
));
table
.
push_back
(
sin
(
0.25
*
i
));
forceField
->
addFunction
(
"fn"
,
new
Discrete1DFunction
(
table
));
forceField
->
addFunction
(
"fn"
,
new
Discrete1DFunction
(
table
));
system
.
addForce
(
forceField
);
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
Context
context
(
system
,
integrator
,
platform
);
...
@@ -295,6 +295,76 @@ void testDiscrete1DFunction() {
...
@@ -295,6 +295,76 @@ void testDiscrete1DFunction() {
}
}
}
}
void
testDiscrete2DFunction
()
{
const
int
xsize
=
10
;
const
int
ysize
=
5
;
ReferencePlatform
platform
;
System
system
;
system
.
addParticle
(
1.0
);
system
.
addParticle
(
1.0
);
VerletIntegrator
integrator
(
0.01
);
CustomNonbondedForce
*
forceField
=
new
CustomNonbondedForce
(
"fn(r,a)+1"
);
forceField
->
addGlobalParameter
(
"a"
,
0.0
);
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
xsize
;
i
++
)
for
(
int
j
=
0
;
j
<
ysize
;
j
++
)
table
.
push_back
(
sin
(
0.25
*
i
)
+
cos
(
0.33
*
j
));
forceField
->
addFunction
(
"fn"
,
new
Discrete2DFunction
(
xsize
,
ysize
,
table
));
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
2
);
positions
[
0
]
=
Vec3
(
0
,
0
,
0
);
for
(
int
i
=
0
;
i
<
(
int
)
table
.
size
();
i
++
)
{
positions
[
1
]
=
Vec3
(
i
%
xsize
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setParameter
(
"a"
,
i
/
xsize
);
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
0
],
1e-6
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
1
],
1e-6
);
ASSERT_EQUAL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
());
}
}
void
testDiscrete3DFunction
()
{
const
int
xsize
=
8
;
const
int
ysize
=
5
;
const
int
zsize
=
6
;
ReferencePlatform
platform
;
System
system
;
system
.
addParticle
(
1.0
);
system
.
addParticle
(
1.0
);
VerletIntegrator
integrator
(
0.01
);
CustomNonbondedForce
*
forceField
=
new
CustomNonbondedForce
(
"fn(r,a,b)+1"
);
forceField
->
addGlobalParameter
(
"a"
,
0.0
);
forceField
->
addGlobalParameter
(
"b"
,
0.0
);
forceField
->
addParticle
(
vector
<
double
>
());
forceField
->
addParticle
(
vector
<
double
>
());
vector
<
double
>
table
;
for
(
int
i
=
0
;
i
<
xsize
;
i
++
)
for
(
int
j
=
0
;
j
<
ysize
;
j
++
)
for
(
int
k
=
0
;
k
<
zsize
;
k
++
)
table
.
push_back
(
sin
(
0.25
*
i
)
+
cos
(
0.33
*
j
)
+
0.12345
*
k
);
forceField
->
addFunction
(
"fn"
,
new
Discrete3DFunction
(
xsize
,
ysize
,
zsize
,
table
));
system
.
addForce
(
forceField
);
Context
context
(
system
,
integrator
,
platform
);
vector
<
Vec3
>
positions
(
2
);
positions
[
0
]
=
Vec3
(
0
,
0
,
0
);
for
(
int
i
=
0
;
i
<
(
int
)
table
.
size
();
i
++
)
{
positions
[
1
]
=
Vec3
(
i
%
xsize
,
0
,
0
);
context
.
setPositions
(
positions
);
context
.
setParameter
(
"a"
,
(
i
/
xsize
)
%
ysize
);
context
.
setParameter
(
"b"
,
i
/
(
xsize
*
ysize
));
State
state
=
context
.
getState
(
State
::
Forces
|
State
::
Energy
);
const
vector
<
Vec3
>&
forces
=
state
.
getForces
();
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
0
],
1e-6
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
0
),
forces
[
1
],
1e-6
);
ASSERT_EQUAL
(
table
[
i
]
+
1.0
,
state
.
getPotentialEnergy
());
}
}
void
testCoulombLennardJones
()
{
void
testCoulombLennardJones
()
{
const
int
numMolecules
=
300
;
const
int
numMolecules
=
300
;
const
int
numParticles
=
numMolecules
*
2
;
const
int
numParticles
=
numMolecules
*
2
;
...
@@ -688,6 +758,8 @@ int main() {
...
@@ -688,6 +758,8 @@ int main() {
testPeriodic
();
testPeriodic
();
testContinuous1DFunction
();
testContinuous1DFunction
();
testDiscrete1DFunction
();
testDiscrete1DFunction
();
testDiscrete2DFunction
();
testDiscrete3DFunction
();
testCoulombLennardJones
();
testCoulombLennardJones
();
testSwitchingFunction
();
testSwitchingFunction
();
testLongRangeCorrection
();
testLongRangeCorrection
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment