Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
dd4eed16
Unverified
Commit
dd4eed16
authored
Feb 22, 2019
by
peastman
Committed by
GitHub
Feb 22, 2019
Browse files
Merge pull request #2266 from peastman/vector
Fixed bug in vector operations in CustomIntegrator on CUDA
parents
f36f2dba
ddcd2f66
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
4 deletions
+30
-4
platforms/cuda/include/CudaExpressionUtilities.h
platforms/cuda/include/CudaExpressionUtilities.h
+1
-0
platforms/cuda/src/CudaExpressionUtilities.cpp
platforms/cuda/src/CudaExpressionUtilities.cpp
+24
-3
tests/TestCustomIntegrator.h
tests/TestCustomIntegrator.h
+5
-1
No files found.
platforms/cuda/include/CudaExpressionUtilities.h
View file @
dd4eed16
...
@@ -123,6 +123,7 @@ private:
...
@@ -123,6 +123,7 @@ private:
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
void
findRelatedPowers
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
std
::
map
<
int
,
const
Lepton
::
ExpressionTreeNode
*>&
powers
);
void
callFunction
(
std
::
stringstream
&
out
,
std
::
string
singleFn
,
std
::
string
doubleFn
,
const
std
::
string
&
arg
,
const
std
::
string
&
tempType
);
void
callFunction
(
std
::
stringstream
&
out
,
std
::
string
singleFn
,
std
::
string
doubleFn
,
const
std
::
string
&
arg
,
const
std
::
string
&
tempType
);
void
callFunction2
(
std
::
stringstream
&
out
,
std
::
string
fn
,
const
std
::
string
&
arg1
,
const
std
::
string
&
arg2
,
const
std
::
string
&
tempType
);
std
::
vector
<
std
::
vector
<
double
>
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
std
::
vector
<
std
::
vector
<
double
>
>
computeFunctionParameters
(
const
std
::
vector
<
const
TabulatedFunction
*>&
functions
);
CudaContext
&
context
;
CudaContext
&
context
;
FunctionPlaceholder
fp1
,
fp2
,
fp3
,
periodicDistance
;
FunctionPlaceholder
fp1
,
fp2
,
fp3
,
periodicDistance
;
...
...
platforms/cuda/src/CudaExpressionUtilities.cpp
View file @
dd4eed16
...
@@ -547,7 +547,16 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
...
@@ -547,7 +547,16 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out
<<
"RECIP("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
out
<<
"RECIP("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
break
;
case
Operation
::
ADD_CONSTANT
:
case
Operation
::
ADD_CONSTANT
:
out
<<
context
.
doubleToString
(
dynamic_cast
<
const
Operation
::
AddConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"+"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
if
(
isVecType
)
{
string
val
=
context
.
doubleToString
(
dynamic_cast
<
const
Operation
::
AddConstant
*>
(
&
node
.
getOperation
())
->
getValue
());
string
arg
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
out
<<
"make_"
<<
tempType
<<
"("
;
out
<<
val
<<
"+"
<<
arg
<<
".x, "
;
out
<<
val
<<
"+"
<<
arg
<<
".y, "
;
out
<<
val
<<
"+"
<<
arg
<<
".z)"
;
}
else
out
<<
context
.
doubleToString
(
dynamic_cast
<
const
Operation
::
AddConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"+"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
break
;
case
Operation
::
MULTIPLY_CONSTANT
:
case
Operation
::
MULTIPLY_CONSTANT
:
out
<<
context
.
doubleToString
(
dynamic_cast
<
const
Operation
::
MultiplyConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"*"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
out
<<
context
.
doubleToString
(
dynamic_cast
<
const
Operation
::
MultiplyConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"*"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
...
@@ -610,10 +619,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
...
@@ -610,10 +619,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
break
;
break
;
}
}
case
Operation
::
MIN
:
case
Operation
::
MIN
:
out
<<
"min(("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", ("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
")"
;
callFunction2
(
out
,
"min"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
getTempName
(
node
.
getChildren
()[
1
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
MAX
:
case
Operation
::
MAX
:
out
<<
"max(("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", ("
<<
tempType
<<
") "
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
")"
;
callFunction2
(
out
,
"max"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
)
,
getTempName
(
node
.
getChildren
()[
1
],
temps
)
,
tempType
)
;
break
;
break
;
case
Operation
::
ABS
:
case
Operation
::
ABS
:
callFunction
(
out
,
"fabs"
,
"fabs"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
),
tempType
);
callFunction
(
out
,
"fabs"
,
"fabs"
,
getTempName
(
node
.
getChildren
()[
0
],
temps
),
tempType
);
...
@@ -927,3 +936,15 @@ void CudaExpressionUtilities::callFunction(stringstream& out, string singleFn, s
...
@@ -927,3 +936,15 @@ void CudaExpressionUtilities::callFunction(stringstream& out, string singleFn, s
out
<<
singleFn
<<
"("
<<
arg
<<
")"
;
out
<<
singleFn
<<
"("
<<
arg
<<
")"
;
}
}
}
}
void
CudaExpressionUtilities
::
callFunction2
(
stringstream
&
out
,
string
fn
,
const
string
&
arg1
,
const
string
&
arg2
,
const
string
&
tempType
)
{
bool
isVector
=
(
tempType
[
tempType
.
size
()
-
1
]
==
'3'
);
if
(
isVector
)
{
out
<<
"make_"
<<
tempType
<<
"("
;
out
<<
fn
<<
"("
<<
arg1
<<
".x, "
<<
arg2
<<
".x), "
;
out
<<
fn
<<
"("
<<
arg1
<<
".y, "
<<
arg2
<<
".y), "
;
out
<<
fn
<<
"("
<<
arg1
<<
".z, "
<<
arg2
<<
".z))"
;
}
else
out
<<
fn
<<
"(("
<<
tempType
<<
") "
<<
arg1
<<
", ("
<<
tempType
<<
") "
<<
arg2
<<
")"
;
}
tests/TestCustomIntegrator.h
View file @
dd4eed16
...
@@ -1042,10 +1042,12 @@ void testVectorFunctions() {
...
@@ -1042,10 +1042,12 @@ void testVectorFunctions() {
integrator
.
addPerDofVariable
(
"angular"
,
0.0
);
integrator
.
addPerDofVariable
(
"angular"
,
0.0
);
integrator
.
addPerDofVariable
(
"shuffle"
,
0.0
);
integrator
.
addPerDofVariable
(
"shuffle"
,
0.0
);
integrator
.
addPerDofVariable
(
"multicross"
,
0.0
);
integrator
.
addPerDofVariable
(
"multicross"
,
0.0
);
integrator
.
addPerDofVariable
(
"maxplus"
,
0.0
);
integrator
.
addComputeSum
(
"sumy"
,
"x*vector(0, 1, 0)"
);
integrator
.
addComputeSum
(
"sumy"
,
"x*vector(0, 1, 0)"
);
integrator
.
addComputePerDof
(
"angular"
,
"cross(v, x)"
);
integrator
.
addComputePerDof
(
"angular"
,
"cross(v, x)"
);
integrator
.
addComputePerDof
(
"shuffle"
,
"dot(vector(_z(x), _x(x), _y(x)), v)"
);
integrator
.
addComputePerDof
(
"shuffle"
,
"dot(vector(_z(x), _x(x), _y(x)), v)"
);
integrator
.
addComputePerDof
(
"multicross"
,
"cross(vector(1, 0, 0), cross(vector(0, 0, 1), vector(1, 0, 0)))"
);
integrator
.
addComputePerDof
(
"multicross"
,
"cross(vector(1, 0, 0), cross(vector(0, 0, 1), vector(1, 0, 0)))"
);
integrator
.
addComputePerDof
(
"maxplus"
,
"max(x, 0.1)+0.5"
);
OpenMM_SFMT
::
SFMT
sfmt
;
OpenMM_SFMT
::
SFMT
sfmt
;
init_gen_rand
(
0
,
sfmt
);
init_gen_rand
(
0
,
sfmt
);
vector
<
Vec3
>
positions
(
numParticles
);
vector
<
Vec3
>
positions
(
numParticles
);
...
@@ -1063,14 +1065,16 @@ void testVectorFunctions() {
...
@@ -1063,14 +1065,16 @@ void testVectorFunctions() {
// See if the expressions were computed correctly.
// See if the expressions were computed correctly.
double
sumy
=
0
;
double
sumy
=
0
;
vector
<
Vec3
>
angular
,
shuffle
,
multicross
;
vector
<
Vec3
>
angular
,
shuffle
,
multicross
,
maxplus
;
integrator
.
getPerDofVariable
(
0
,
angular
);
integrator
.
getPerDofVariable
(
0
,
angular
);
integrator
.
getPerDofVariable
(
1
,
shuffle
);
integrator
.
getPerDofVariable
(
1
,
shuffle
);
integrator
.
getPerDofVariable
(
2
,
multicross
);
integrator
.
getPerDofVariable
(
2
,
multicross
);
integrator
.
getPerDofVariable
(
3
,
maxplus
);
for
(
int
i
=
0
;
i
<
numParticles
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numParticles
;
i
++
)
{
ASSERT_EQUAL_VEC
(
velocities
[
i
].
cross
(
positions
[
i
]),
angular
[
i
],
1e-5
);
ASSERT_EQUAL_VEC
(
velocities
[
i
].
cross
(
positions
[
i
]),
angular
[
i
],
1e-5
);
ASSERT_EQUAL_VEC
(
Vec3
(
1
,
1
,
1
)
*
velocities
[
i
].
dot
(
Vec3
(
positions
[
i
][
2
],
positions
[
i
][
0
],
positions
[
i
][
1
])),
shuffle
[
i
],
1e-5
);
ASSERT_EQUAL_VEC
(
Vec3
(
1
,
1
,
1
)
*
velocities
[
i
].
dot
(
Vec3
(
positions
[
i
][
2
],
positions
[
i
][
0
],
positions
[
i
][
1
])),
shuffle
[
i
],
1e-5
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
1
),
multicross
[
i
],
1e-5
);
ASSERT_EQUAL_VEC
(
Vec3
(
0
,
0
,
1
),
multicross
[
i
],
1e-5
);
ASSERT_EQUAL_VEC
(
Vec3
(
max
(
positions
[
i
][
0
],
0.1
)
+
0.5
,
max
(
positions
[
i
][
1
],
0.1
)
+
0.5
,
max
(
positions
[
i
][
2
],
0.1
)
+
0.5
),
maxplus
[
i
],
1e-5
);
sumy
+=
positions
[
i
][
1
];
sumy
+=
positions
[
i
][
1
];
}
}
ASSERT_EQUAL_TOL
(
sumy
,
integrator
.
getGlobalVariable
(
0
),
1e-5
);
ASSERT_EQUAL_TOL
(
sumy
,
integrator
.
getGlobalVariable
(
0
),
1e-5
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment