Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
fdc59e96
Commit
fdc59e96
authored
May 09, 2018
by
Peter Eastman
Browse files
Bug fixes to vector expressions
parent
b53c6593
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
86 additions
and
26 deletions
+86
-26
platforms/cuda/include/CudaExpressionUtilities.h
platforms/cuda/include/CudaExpressionUtilities.h
+1
-0
platforms/cuda/src/CudaExpressionUtilities.cpp
platforms/cuda/src/CudaExpressionUtilities.cpp
+68
-22
platforms/cuda/src/kernels/customIntegratorPerDof.cu
platforms/cuda/src/kernels/customIntegratorPerDof.cu
+1
-1
platforms/opencl/src/OpenCLExpressionUtilities.cpp
platforms/opencl/src/OpenCLExpressionUtilities.cpp
+2
-2
tests/TestCustomIntegrator.h
tests/TestCustomIntegrator.h
+14
-1
No files found.
platforms/cuda/include/CudaExpressionUtilities.h
View file @
fdc59e96
...
@@ -122,6 +122,7 @@ private:
...
@@ -122,6 +122,7 @@ private:
std
::
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
);
std
::
vector
<
const
Lepton
::
ExpressionTreeNode
*>&
nodes
);
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
void
callFunction
(
std
::
stringstream
&
out
,
std
::
string
singleFn
,
std
::
string
doubleFn
,
const
std
::
string
&
arg
,
const
std
::
string
&
tempType
);
std
::
vector
<
std
::
vector
<
double
>
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
std
::
vector
<
std
::
vector
<
double
>
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
CudaContext
&
context
;
CudaContext
&
context
;
FunctionPlaceholder
fp1
,
fp2
,
fp3
,
periodicDistance
;
FunctionPlaceholder
fp1
,
fp2
,
fp3
,
periodicDistance
;
...
...
platforms/cuda/src/CudaExpressionUtilities.cpp
View file @
fdc59e96
...
@@ -445,62 +445,91 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
...
@@ -445,62 +445,91 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out
<<
"-"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
out
<<
"-"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
break
;
case
Operation
::
SQRT
:
case
Operation
::
SQRT
:
out
<<
"SQRT("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"sqrtf"
,
"sqrt"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
);
break
;
break
;
case
Operation
::
EXP
:
case
Operation
::
EXP
:
out
<<
"EXP("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"expf"
,
"exp"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
LOG
:
case
Operation
::
LOG
:
out
<<
"LOG("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"logf"
,
"log"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
SIN
:
case
Operation
::
SIN
:
out
<<
"SIN("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"sinf"
,
"sin"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
COS
:
case
Operation
::
COS
:
out
<<
"COS("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"cosf"
,
"cos"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
SEC
:
case
Operation
::
SEC
:
out
<<
"RECIP(COS("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"))"
;
out
<<
"1/"
;
callFunction
(
out
,
"cosf"
,
"cos"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
),
tempType
);
break
;
break
;
case
Operation
::
CSC
:
case
Operation
::
CSC
:
out
<<
"RECIP(SIN("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"))"
;
out
<<
"1/"
;
callFunction
(
out
,
"sinf"
,
"sin"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
),
tempType
);
break
;
break
;
case
Operation
::
TAN
:
case
Operation
::
TAN
:
out
<<
"TAN("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"tanf"
,
"tan"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
COT
:
case
Operation
::
COT
:
out
<<
"RECIP(TAN("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"))"
;
out
<<
"1/"
;
callFunction
(
out
,
"tanf"
,
"tan"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
),
tempType
);
break
;
break
;
case
Operation
::
ASIN
:
case
Operation
::
ASIN
:
out
<<
"ASIN("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"asinf"
,
"asin"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
ACOS
:
case
Operation
::
ACOS
:
out
<<
"ACOS("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"acosf"
,
"acos"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
ATAN
:
case
Operation
::
ATAN
:
out
<<
"ATAN("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"atanf"
,
"atan"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
SINH
:
case
Operation
::
SINH
:
out
<<
"sinh("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"sinh"
,
"sinh"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
COSH
:
case
Operation
::
COSH
:
out
<<
"cosh("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"cosh"
,
"cosh"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
TANH
:
case
Operation
::
TANH
:
out
<<
"tanh("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"tanh"
,
"tanh"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
ERF
:
case
Operation
::
ERF
:
out
<<
"erf("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"erf"
,
"erf"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
ERFC
:
case
Operation
::
ERFC
:
out
<<
"erfc("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"erfc"
,
"erfc"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
STEP
:
case
Operation
::
STEP
:
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
" >= 0.0f ? 1.0f : 0.0f"
;
{
string
compareVal
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
if
(
isVecType
)
{
out
<<
"make_"
<<
tempType
<<
"(0);
\n
"
;
out
<<
"{
\n
"
;
out
<<
tempType
<<
" tempCompareValue = "
<<
compareVal
<<
";
\n
"
;
out
<<
name
<<
".x = (tempCompareValue.x >= 0 ? 1 : 0);
\n
"
;
out
<<
name
<<
".y = (tempCompareValue.y >= 0 ? 1 : 0);
\n
"
;
out
<<
name
<<
".z = (tempCompareValue.z >= 0 ? 1 : 0);
\n
"
;
out
<<
"}
\n
"
;
}
else
out
<<
compareVal
<<
" >= 0 ? 1 : 0"
;
break
;
break
;
}
case
Operation
::
DELTA
:
case
Operation
::
DELTA
:
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
" == 0.0f ? 1.0f : 0.0f"
;
{
string
compareVal
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
if
(
isVecType
)
{
out
<<
"make_"
<<
tempType
<<
"(0);
\n
"
;
out
<<
"{
\n
"
;
out
<<
tempType
<<
" tempCompareValue = "
<<
compareVal
<<
";
\n
"
;
out
<<
name
<<
".x = (tempCompareValue.x == 0 ? 1 : 0);
\n
"
;
out
<<
name
<<
".y = (tempCompareValue.y == 0 ? 1 : 0);
\n
"
;
out
<<
name
<<
".z = (tempCompareValue.z == 0 ? 1 : 0);
\n
"
;
out
<<
"}
\n
"
;
}
else
out
<<
compareVal
<<
" == 0 ? 1 : 0"
;
break
;
break
;
}
case
Operation
::
SQUARE
:
case
Operation
::
SQUARE
:
{
{
string
arg
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
string
arg
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
...
@@ -586,13 +615,13 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
...
@@ -586,13 +615,13 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out
<<
"max(("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", ("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
")"
;
out
<<
"max(("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", ("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
")"
;
break
;
break
;
case
Operation
::
ABS
:
case
Operation
::
ABS
:
out
<<
"fabs("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"fabs"
,
"fabs"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
FLOOR
:
case
Operation
::
FLOOR
:
out
<<
"floor
("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"floor"
,
"floor
"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
CEIL
:
case
Operation
::
CEIL
:
out
<<
"ceil("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
callFunction
(
out
,
"ceil"
,
"ceil"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
SELECT
:
case
Operation
::
SELECT
:
{
{
...
@@ -880,3 +909,20 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta
...
@@ -880,3 +909,20 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta
Lepton
::
CustomFunction
*
CudaExpressionUtilities
::
getPeriodicDistancePlaceholder
()
{
Lepton
::
CustomFunction
*
CudaExpressionUtilities
::
getPeriodicDistancePlaceholder
()
{
return
&
periodicDistance
;
return
&
periodicDistance
;
}
}
void
CudaExpressionUtilities
::
callFunction
(
stringstream
&
out
,
string
singleFn
,
string
doubleFn
,
const
string
&
arg
,
const
string
&
tempType
)
{
bool
isDouble
=
(
tempType
[
0
]
==
'd'
);
bool
isVector
=
(
tempType
[
tempType
.
size
()
-
1
]
==
'3'
);
if
(
isVector
)
{
if
(
isDouble
)
out
<<
"make_double3("
<<
doubleFn
<<
"("
<<
arg
<<
".x), "
<<
doubleFn
<<
"("
<<
arg
<<
".y), "
<<
doubleFn
<<
"("
<<
arg
<<
".z))"
;
else
out
<<
"make_float3("
<<
singleFn
<<
"("
<<
arg
<<
".x), "
<<
singleFn
<<
"("
<<
arg
<<
".y), "
<<
singleFn
<<
"("
<<
arg
<<
".z))"
;
}
else
{
if
(
isDouble
)
out
<<
doubleFn
<<
"("
<<
arg
<<
")"
;
else
out
<<
singleFn
<<
"("
<<
arg
<<
")"
;
}
}
platforms/cuda/src/kernels/customIntegratorPerDof.cu
View file @
fdc59e96
...
@@ -51,7 +51,7 @@ extern "C" __global__ void computePerDof(real4* __restrict__ posq, real4* __rest
...
@@ -51,7 +51,7 @@ extern "C" __global__ void computePerDof(real4* __restrict__ posq, real4* __rest
#endif
#endif
double4
velocity
=
convertToDouble4
(
velm
[
index
]);
double4
velocity
=
convertToDouble4
(
velm
[
index
]);
double4
f
=
make_double4
(
forceScale
*
force
[
index
],
forceScale
*
force
[
index
+
PADDED_NUM_ATOMS
],
forceScale
*
force
[
index
+
PADDED_NUM_ATOMS
*
2
],
0.0
);
double4
f
=
make_double4
(
forceScale
*
force
[
index
],
forceScale
*
force
[
index
+
PADDED_NUM_ATOMS
],
forceScale
*
force
[
index
+
PADDED_NUM_ATOMS
*
2
],
0.0
);
double
mass
=
1.0
/
velocity
.
w
;
double
3
mass
=
make_double3
(
1.0
/
velocity
.
w
)
;
if
(
velocity
.
w
!=
0.0
)
{
if
(
velocity
.
w
!=
0.0
)
{
int
gaussianIndex
=
gaussianBaseIndex
;
int
gaussianIndex
=
gaussianBaseIndex
;
int
uniformIndex
=
0
;
int
uniformIndex
=
0
;
...
...
platforms/opencl/src/OpenCLExpressionUtilities.cpp
View file @
fdc59e96
...
@@ -478,10 +478,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -478,10 +478,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out
<<
"erfc("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
out
<<
"erfc("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
break
;
case
Operation
::
STEP
:
case
Operation
::
STEP
:
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
" >= 0.0f ?
1.0f : 0.0f
"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
" >= 0.0f ?
("
<<
tempType
<<
") 1 : ("
<<
tempType
<<
") 0
"
;
break
;
break
;
case
Operation
::
DELTA
:
case
Operation
::
DELTA
:
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
" == 0.0f ?
1.0f : 0.0f
"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
" == 0.0f ?
("
<<
tempType
<<
") 1 : ("
<<
tempType
<<
") 0
"
;
break
;
break
;
case
Operation
::
SQUARE
:
case
Operation
::
SQUARE
:
{
{
...
...
tests/TestCustomIntegrator.h
View file @
fdc59e96
...
@@ -537,9 +537,11 @@ void testPerDofVariables() {
...
@@ -537,9 +537,11 @@ void testPerDofVariables() {
CustomIntegrator
integrator
(
0.01
);
CustomIntegrator
integrator
(
0.01
);
integrator
.
addPerDofVariable
(
"temp"
,
0
);
integrator
.
addPerDofVariable
(
"temp"
,
0
);
integrator
.
addPerDofVariable
(
"pos"
,
0
);
integrator
.
addPerDofVariable
(
"pos"
,
0
);
integrator
.
addPerDofVariable
(
"computed"
,
0
);
integrator
.
addComputePerDof
(
"v"
,
"v+dt*f/m"
);
integrator
.
addComputePerDof
(
"v"
,
"v+dt*f/m"
);
integrator
.
addComputePerDof
(
"x"
,
"x+dt*v"
);
integrator
.
addComputePerDof
(
"x"
,
"x+dt*v"
);
integrator
.
addComputePerDof
(
"pos"
,
"x"
);
integrator
.
addComputePerDof
(
"pos"
,
"x"
);
integrator
.
addComputePerDof
(
"computed"
,
"step(v)*log(x^2)"
);
Context
context
(
system
,
integrator
,
platform
);
Context
context
(
system
,
integrator
,
platform
);
context
.
setPositions
(
positions
);
context
.
setPositions
(
positions
);
vector
<
Vec3
>
initialValues
(
numParticles
);
vector
<
Vec3
>
initialValues
(
numParticles
);
...
@@ -552,13 +554,24 @@ void testPerDofVariables() {
...
@@ -552,13 +554,24 @@ void testPerDofVariables() {
vector
<
Vec3
>
values
;
vector
<
Vec3
>
values
;
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
integrator
.
step
(
1
);
integrator
.
step
(
1
);
State
state
=
context
.
getState
(
State
::
Positions
);
State
state
=
context
.
getState
(
State
::
Positions
|
State
::
Velocities
);
integrator
.
getPerDofVariable
(
0
,
values
);
integrator
.
getPerDofVariable
(
0
,
values
);
for
(
int
j
=
0
;
j
<
numParticles
;
j
++
)
for
(
int
j
=
0
;
j
<
numParticles
;
j
++
)
ASSERT_EQUAL_VEC
(
initialValues
[
j
],
values
[
j
],
1e-5
);
ASSERT_EQUAL_VEC
(
initialValues
[
j
],
values
[
j
],
1e-5
);
integrator
.
getPerDofVariable
(
1
,
values
);
integrator
.
getPerDofVariable
(
1
,
values
);
for
(
int
j
=
0
;
j
<
numParticles
;
j
++
)
for
(
int
j
=
0
;
j
<
numParticles
;
j
++
)
ASSERT_EQUAL_VEC
(
state
.
getPositions
()[
j
],
values
[
j
],
1e-5
);
ASSERT_EQUAL_VEC
(
state
.
getPositions
()[
j
],
values
[
j
],
1e-5
);
integrator
.
getPerDofVariable
(
2
,
values
);
for
(
int
j
=
0
;
j
<
numParticles
;
j
++
)
for
(
int
k
=
0
;
k
<
3
;
k
++
)
{
if
(
state
.
getVelocities
()[
j
][
k
]
<
0
)
{
ASSERT
(
values
[
j
][
k
]
==
0.0
);
}
else
{
double
v
=
state
.
getPositions
()[
j
][
k
];
ASSERT_EQUAL_TOL
(
log
(
v
*
v
),
values
[
j
][
k
],
1e-5
);
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment