Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
88f32f2d
Unverified
Commit
88f32f2d
authored
May 25, 2025
by
Peter Eastman
Committed by
GitHub
May 25, 2025
Browse files
Optimize computing kinetic energy (#4946)
parent
f19c9f59
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
22 deletions
+47
-22
platforms/common/include/openmm/common/IntegrationUtilities.h
...forms/common/include/openmm/common/IntegrationUtilities.h
+3
-2
platforms/common/src/IntegrationUtilities.cpp
platforms/common/src/IntegrationUtilities.cpp
+23
-20
platforms/common/src/kernels/integrationUtilities.cc
platforms/common/src/kernels/integrationUtilities.cc
+21
-0
No files found.
platforms/common/include/openmm/common/IntegrationUtilities.h
View file @
88f32f2d
...
@@ -145,7 +145,7 @@ protected:
...
@@ -145,7 +145,7 @@ protected:
ComputeKernel
ccmaDirectionsKernel
,
ccmaPosForceKernel
,
ccmaVelForceKernel
;
ComputeKernel
ccmaDirectionsKernel
,
ccmaPosForceKernel
,
ccmaVelForceKernel
;
ComputeKernel
ccmaMultiplyKernel
,
ccmaUpdateKernel
,
ccmaFullKernel
;
ComputeKernel
ccmaMultiplyKernel
,
ccmaUpdateKernel
,
ccmaFullKernel
;
ComputeKernel
vsitePositionKernel
,
vsiteForceKernel
,
vsiteSaveForcesKernel
;
ComputeKernel
vsitePositionKernel
,
vsiteForceKernel
,
vsiteSaveForcesKernel
;
ComputeKernel
randomKernel
,
timeShiftKernel
;
ComputeKernel
randomKernel
,
timeShiftKernel
,
kineticEnergyKernel
;
ComputeArray
posDelta
;
ComputeArray
posDelta
;
ComputeArray
settleAtoms
;
ComputeArray
settleAtoms
;
ComputeArray
settleParams
;
ComputeArray
settleParams
;
...
@@ -177,7 +177,8 @@ protected:
...
@@ -177,7 +177,8 @@ protected:
ComputeArray
vsiteLocalCoordsPos
;
ComputeArray
vsiteLocalCoordsPos
;
ComputeArray
vsiteLocalCoordsStartIndex
;
ComputeArray
vsiteLocalCoordsStartIndex
;
ComputeArray
vsiteStage
;
ComputeArray
vsiteStage
;
int
randomPos
,
lastSeed
,
numVsites
,
numVsiteStages
;
ComputeArray
kineticEnergy
;
int
randomPos
,
lastSeed
,
numVsites
,
numVsiteStages
,
keWorkGroupSize
;
bool
hasOverlappingVsites
;
bool
hasOverlappingVsites
;
mm_double2
lastStepSize
;
mm_double2
lastStepSize
;
struct
ShakeCluster
;
struct
ShakeCluster
;
...
...
platforms/common/src/IntegrationUtilities.cpp
View file @
88f32f2d
...
@@ -101,6 +101,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
...
@@ -101,6 +101,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
posDelta
.
upload
(
deltas
);
posDelta
.
upload
(
deltas
);
stepSize
.
initialize
<
mm_double2
>
(
context
,
1
,
"stepSize"
);
stepSize
.
initialize
<
mm_double2
>
(
context
,
1
,
"stepSize"
);
stepSize
.
upload
(
&
lastStepSize
);
stepSize
.
upload
(
&
lastStepSize
);
kineticEnergy
.
initialize
<
double
>
(
context
,
1
,
"kineticEnergy"
);
}
}
else
{
else
{
posDelta
.
initialize
<
mm_float4
>
(
context
,
context
.
getPaddedNumAtoms
(),
"posDelta"
);
posDelta
.
initialize
<
mm_float4
>
(
context
,
context
.
getPaddedNumAtoms
(),
"posDelta"
);
...
@@ -109,7 +110,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
...
@@ -109,7 +110,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
stepSize
.
initialize
<
mm_float2
>
(
context
,
1
,
"stepSize"
);
stepSize
.
initialize
<
mm_float2
>
(
context
,
1
,
"stepSize"
);
mm_float2
lastStepSizeFloat
=
mm_float2
(
0.0
f
,
0.0
f
);
mm_float2
lastStepSizeFloat
=
mm_float2
(
0.0
f
,
0.0
f
);
stepSize
.
upload
(
&
lastStepSizeFloat
);
stepSize
.
upload
(
&
lastStepSizeFloat
);
kineticEnergy
.
initialize
<
float
>
(
context
,
1
,
"kineticEnergy"
);
}
}
keWorkGroupSize
=
context
.
getMaxThreadBlockSize
();
if
(
keWorkGroupSize
>
512
)
keWorkGroupSize
=
512
;
// Record the set of constraints and how many constraints each atom is involved in.
// Record the set of constraints and how many constraints each atom is involved in.
...
@@ -573,6 +578,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
...
@@ -573,6 +578,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
defines
[
"NUM_OUT_OF_PLANE"
]
=
context
.
intToString
(
numOutOfPlane
);
defines
[
"NUM_OUT_OF_PLANE"
]
=
context
.
intToString
(
numOutOfPlane
);
defines
[
"NUM_LOCAL_COORDS"
]
=
context
.
intToString
(
numLocalCoords
);
defines
[
"NUM_LOCAL_COORDS"
]
=
context
.
intToString
(
numLocalCoords
);
defines
[
"PADDED_NUM_ATOMS"
]
=
context
.
intToString
(
context
.
getPaddedNumAtoms
());
defines
[
"PADDED_NUM_ATOMS"
]
=
context
.
intToString
(
context
.
getPaddedNumAtoms
());
defines
[
"KE_WORK_GROUP_SIZE"
]
=
context
.
intToString
(
keWorkGroupSize
);
if
(
hasOverlappingVsites
)
if
(
hasOverlappingVsites
)
defines
[
"HAS_OVERLAPPING_VSITES"
]
=
"1"
;
defines
[
"HAS_OVERLAPPING_VSITES"
]
=
"1"
;
if
(
numVsiteStages
>
1
)
if
(
numVsiteStages
>
1
)
...
@@ -593,6 +599,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
...
@@ -593,6 +599,7 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
vsiteSaveForcesKernel
=
program
->
createKernel
(
"saveDistributedForces"
);
vsiteSaveForcesKernel
=
program
->
createKernel
(
"saveDistributedForces"
);
randomKernel
=
program
->
createKernel
(
"generateRandomNumbers"
);
randomKernel
=
program
->
createKernel
(
"generateRandomNumbers"
);
timeShiftKernel
=
program
->
createKernel
(
"timeShiftVelocities"
);
timeShiftKernel
=
program
->
createKernel
(
"timeShiftVelocities"
);
kineticEnergyKernel
=
program
->
createKernel
(
"computeKineticEnergy"
);
// Set arguments for virtual site kernels.
// Set arguments for virtual site kernels.
...
@@ -740,6 +747,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
...
@@ -740,6 +747,11 @@ IntegrationUtilities::IntegrationUtilities(ComputeContext& context, const System
for
(
int
i
=
0
;
i
<
3
;
i
++
)
for
(
int
i
=
0
;
i
<
3
;
i
++
)
timeShiftKernel
->
addArg
();
timeShiftKernel
->
addArg
();
// Set arguments of kinetic energy kernel.
kineticEnergyKernel
->
addArg
(
context
.
getVelm
());
kineticEnergyKernel
->
addArg
(
kineticEnergy
);
}
}
void
IntegrationUtilities
::
setNextStepSize
(
double
size
)
{
void
IntegrationUtilities
::
setNextStepSize
(
double
size
)
{
...
@@ -874,31 +886,22 @@ double IntegrationUtilities::computeKineticEnergy(double timeShift) {
...
@@ -874,31 +886,22 @@ double IntegrationUtilities::computeKineticEnergy(double timeShift) {
// Compute the kinetic energy.
// Compute the kinetic energy.
double
energy
=
0.0
;
kineticEnergyKernel
->
execute
(
keWorkGroupSize
,
keWorkGroupSize
);
if
(
context
.
getUseDoublePrecision
()
||
context
.
getUseMixedPrecision
())
{
auto
velm
=
(
mm_double4
*
)
context
.
getPinnedBuffer
();
context
.
getVelm
().
download
(
velm
);
for
(
int
i
=
0
;
i
<
numParticles
;
i
++
)
{
mm_double4
v
=
velm
[
i
];
if
(
v
.
w
!=
0
)
energy
+=
(
v
.
x
*
v
.
x
+
v
.
y
*
v
.
y
+
v
.
z
*
v
.
z
)
/
v
.
w
;
}
}
else
{
auto
velm
=
(
mm_float4
*
)
context
.
getPinnedBuffer
();
context
.
getVelm
().
download
(
velm
);
for
(
int
i
=
0
;
i
<
numParticles
;
i
++
)
{
mm_float4
v
=
velm
[
i
];
if
(
v
.
w
!=
0
)
energy
+=
(
v
.
x
*
v
.
x
+
v
.
y
*
v
.
y
+
v
.
z
*
v
.
z
)
/
v
.
w
;
}
}
// Restore the velocities.
// Restore the velocities.
if
(
timeShift
!=
0
)
if
(
timeShift
!=
0
)
posDelta
.
copyTo
(
context
.
getVelm
());
posDelta
.
copyTo
(
context
.
getVelm
());
return
0.5
*
energy
;
if
(
context
.
getUseDoublePrecision
()
||
context
.
getUseMixedPrecision
())
{
double
energy
;
kineticEnergy
.
download
(
&
energy
);
return
energy
;
}
else
{
float
energy
;
kineticEnergy
.
download
(
&
energy
);
return
energy
;
}
}
}
void
IntegrationUtilities
::
computeShiftedVelocities
(
double
timeShift
,
vector
<
Vec3
>&
velocities
)
{
void
IntegrationUtilities
::
computeShiftedVelocities
(
double
timeShift
,
vector
<
Vec3
>&
velocities
)
{
...
...
platforms/common/src/kernels/integrationUtilities.cc
View file @
88f32f2d
...
@@ -1075,3 +1075,24 @@ KERNEL void timeShiftVelocities(GLOBAL mixed4* RESTRICT velm, GLOBAL const mm_lo
...
@@ -1075,3 +1075,24 @@ KERNEL void timeShiftVelocities(GLOBAL mixed4* RESTRICT velm, GLOBAL const mm_lo
}
}
}
}
}
}
/**
* Compute the total kinetic energy.
*/
KERNEL
void
computeKineticEnergy
(
GLOBAL
mixed4
*
RESTRICT
velm
,
GLOBAL
mixed
*
result
)
{
LOCAL
mixed
tempBuffer
[
KE_WORK_GROUP_SIZE
];
mixed
sum
=
0
;
for
(
unsigned
int
index
=
LOCAL_ID
;
index
<
NUM_ATOMS
;
index
+=
LOCAL_SIZE
)
{
mixed4
v
=
velm
[
index
];
if
(
v
.
w
!=
0
)
sum
+=
(
v
.
x
*
v
.
x
+
v
.
y
*
v
.
y
+
v
.
z
*
v
.
z
)
/
v
.
w
;
}
tempBuffer
[
LOCAL_ID
]
=
sum
;
for
(
int
i
=
1
;
i
<
KE_WORK_GROUP_SIZE
;
i
*=
2
)
{
SYNC_THREADS
;
if
(
LOCAL_ID
%
(
i
*
2
)
==
0
&&
LOCAL_ID
+
i
<
KE_WORK_GROUP_SIZE
)
tempBuffer
[
LOCAL_ID
]
+=
tempBuffer
[
LOCAL_ID
+
i
];
}
if
(
LOCAL_ID
==
0
)
*
result
=
0.5
f
*
tempBuffer
[
0
];
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment