Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
559da024
Unverified
Commit
559da024
authored
May 23, 2025
by
Peter Eastman
Committed by
GitHub
May 23, 2025
Browse files
Optimized setPositions() and setVelocities() (#4945)
* Optimized setPositions() and setVelocities() * Fix test failures
parent
7df74a1c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
65 additions
and
36 deletions
+65
-36
platforms/common/include/openmm/common/CommonKernels.h
platforms/common/include/openmm/common/CommonKernels.h
+2
-0
platforms/common/src/CommonKernels.cpp
platforms/common/src/CommonKernels.cpp
+43
-36
platforms/common/src/kernels/copyCoordinateBuffers.cc
platforms/common/src/kernels/copyCoordinateBuffers.cc
+17
-0
platforms/cuda/src/CudaContext.cpp
platforms/cuda/src/CudaContext.cpp
+1
-0
platforms/hip/src/HipContext.cpp
platforms/hip/src/HipContext.cpp
+1
-0
platforms/opencl/src/OpenCLContext.cpp
platforms/opencl/src/OpenCLContext.cpp
+1
-0
No files found.
platforms/common/include/openmm/common/CommonKernels.h
View file @
559da024
...
@@ -154,6 +154,8 @@ public:
...
@@ -154,6 +154,8 @@ public:
void
loadCheckpoint
(
ContextImpl
&
context
,
std
::
istream
&
stream
);
void
loadCheckpoint
(
ContextImpl
&
context
,
std
::
istream
&
stream
);
private:
private:
ComputeContext
&
cc
;
ComputeContext
&
cc
;
ComputeArray
floatBuffer
,
doubleBuffer
;
ComputeKernel
copyFloatKernel
,
copyDoubleKernel
;
};
};
/**
/**
...
...
platforms/common/src/CommonKernels.cpp
View file @
559da024
...
@@ -58,6 +58,21 @@ using namespace std;
...
@@ -58,6 +58,21 @@ using namespace std;
using
namespace
Lepton
;
using
namespace
Lepton
;
void
CommonUpdateStateDataKernel
::
initialize
(
const
System
&
system
)
{
void
CommonUpdateStateDataKernel
::
initialize
(
const
System
&
system
)
{
ContextSelector
selector
(
cc
);
floatBuffer
.
initialize
<
float
>
(
cc
,
3
*
system
.
getNumParticles
(),
"floatBuffer"
);
map
<
string
,
string
>
defines
;
ComputeProgram
program
=
cc
.
compileProgram
(
CommonKernelSources
::
copyCoordinateBuffers
,
defines
);
copyFloatKernel
=
program
->
createKernel
(
"copyFloatBuffer"
);
copyFloatKernel
->
addArg
(
floatBuffer
);
copyFloatKernel
->
addArg
();
copyFloatKernel
->
addArg
(
cc
.
getNumAtoms
());
if
(
cc
.
getUseMixedPrecision
()
||
cc
.
getUseDoublePrecision
())
{
doubleBuffer
.
initialize
<
double
>
(
cc
,
3
*
system
.
getNumParticles
(),
"doubleBuffer"
);
copyDoubleKernel
=
program
->
createKernel
(
"copyDoubleBuffer"
);
copyDoubleKernel
->
addArg
(
doubleBuffer
);
copyDoubleKernel
->
addArg
();
copyDoubleKernel
->
addArg
(
cc
.
getNumAtoms
());
}
}
}
double
CommonUpdateStateDataKernel
::
getTime
(
const
ContextImpl
&
context
)
const
{
double
CommonUpdateStateDataKernel
::
getTime
(
const
ContextImpl
&
context
)
const
{
...
@@ -144,32 +159,28 @@ void CommonUpdateStateDataKernel::setPositions(ContextImpl& context, const vecto
...
@@ -144,32 +159,28 @@ void CommonUpdateStateDataKernel::setPositions(ContextImpl& context, const vecto
const
vector
<
int
>&
order
=
cc
.
getAtomIndex
();
const
vector
<
int
>&
order
=
cc
.
getAtomIndex
();
int
numParticles
=
context
.
getSystem
().
getNumParticles
();
int
numParticles
=
context
.
getSystem
().
getNumParticles
();
if
(
cc
.
getUseDoublePrecision
())
{
if
(
cc
.
getUseDoublePrecision
())
{
mm_double4
*
posq
=
(
mm_double4
*
)
cc
.
getPinnedBuffer
();
double
*
pos
=
(
double
*
)
cc
.
getPinnedBuffer
();
cc
.
getPosq
().
download
(
posq
);
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
mm_double4
&
pos
=
posq
[
i
];
const
Vec3
&
p
=
positions
[
order
[
i
]];
const
Vec3
&
p
=
positions
[
order
[
i
]];
pos
.
x
=
p
[
0
];
pos
[
3
*
i
]
=
p
[
0
];
pos
.
y
=
p
[
1
];
pos
[
3
*
i
+
1
]
=
p
[
1
];
pos
.
z
=
p
[
2
];
pos
[
3
*
i
+
2
]
=
p
[
2
];
}
}
for
(
int
i
=
numParticles
;
i
<
cc
.
getPaddedNumAtoms
();
i
++
)
doubleBuffer
.
upload
(
pos
);
posq
[
i
]
=
mm_double4
(
0.0
,
0.0
,
0.0
,
0.0
);
copyDoubleKernel
->
setArg
(
1
,
cc
.
getPosq
()
);
c
c
.
getPosq
().
upload
(
posq
);
c
opyDoubleKernel
->
execute
(
numParticles
);
}
}
else
{
else
{
mm_float4
*
posq
=
(
mm_float4
*
)
cc
.
getPinnedBuffer
();
float
*
pos
=
(
float
*
)
cc
.
getPinnedBuffer
();
cc
.
getPosq
().
download
(
posq
);
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
mm_float4
&
pos
=
posq
[
i
];
const
Vec3
&
p
=
positions
[
order
[
i
]];
const
Vec3
&
p
=
positions
[
order
[
i
]];
pos
.
x
=
(
float
)
p
[
0
];
pos
[
3
*
i
]
=
(
float
)
p
[
0
];
pos
.
y
=
(
float
)
p
[
1
];
pos
[
3
*
i
+
1
]
=
(
float
)
p
[
1
];
pos
.
z
=
(
float
)
p
[
2
];
pos
[
3
*
i
+
2
]
=
(
float
)
p
[
2
];
}
}
f
or
(
int
i
=
numParticles
;
i
<
cc
.
getPaddedNumAtoms
();
i
++
)
f
loatBuffer
.
upload
(
pos
);
posq
[
i
]
=
mm_float4
(
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
);
copyFloatKernel
->
setArg
(
1
,
cc
.
getPosq
()
);
c
c
.
getPosq
().
upload
(
posq
);
c
opyFloatKernel
->
execute
(
numParticles
);
}
}
if
(
cc
.
getUseMixedPrecision
())
{
if
(
cc
.
getUseMixedPrecision
())
{
mm_float4
*
posCorrection
=
(
mm_float4
*
)
cc
.
getPinnedBuffer
();
mm_float4
*
posCorrection
=
(
mm_float4
*
)
cc
.
getPinnedBuffer
();
...
@@ -218,32 +229,28 @@ void CommonUpdateStateDataKernel::setVelocities(ContextImpl& context, const vect
...
@@ -218,32 +229,28 @@ void CommonUpdateStateDataKernel::setVelocities(ContextImpl& context, const vect
const
vector
<
int
>&
order
=
cc
.
getAtomIndex
();
const
vector
<
int
>&
order
=
cc
.
getAtomIndex
();
int
numParticles
=
context
.
getSystem
().
getNumParticles
();
int
numParticles
=
context
.
getSystem
().
getNumParticles
();
if
(
cc
.
getUseDoublePrecision
()
||
cc
.
getUseMixedPrecision
())
{
if
(
cc
.
getUseDoublePrecision
()
||
cc
.
getUseMixedPrecision
())
{
mm_double4
*
velm
=
(
mm_double4
*
)
cc
.
getPinnedBuffer
();
double
*
vel
=
(
double
*
)
cc
.
getPinnedBuffer
();
cc
.
getVelm
().
download
(
velm
);
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
mm_double4
&
vel
=
velm
[
i
];
const
Vec3
&
p
=
velocities
[
order
[
i
]];
const
Vec3
&
p
=
velocities
[
order
[
i
]];
vel
.
x
=
p
[
0
];
vel
[
3
*
i
]
=
p
[
0
];
vel
.
y
=
p
[
1
];
vel
[
3
*
i
+
1
]
=
p
[
1
];
vel
.
z
=
p
[
2
];
vel
[
3
*
i
+
2
]
=
p
[
2
];
}
}
for
(
int
i
=
numParticles
;
i
<
cc
.
getPaddedNumAtoms
();
i
++
)
doubleBuffer
.
upload
(
vel
);
velm
[
i
]
=
mm_double4
(
0.0
,
0.0
,
0.0
,
0.0
);
copyDoubleKernel
->
setArg
(
1
,
cc
.
getVelm
()
);
c
c
.
getVelm
().
upload
(
velm
);
c
opyDoubleKernel
->
execute
(
numParticles
);
}
}
else
{
else
{
mm_float4
*
velm
=
(
mm_float4
*
)
cc
.
getPinnedBuffer
();
float
*
vel
=
(
float
*
)
cc
.
getPinnedBuffer
();
cc
.
getVelm
().
download
(
velm
);
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
for
(
int
i
=
0
;
i
<
numParticles
;
++
i
)
{
mm_float4
&
vel
=
velm
[
i
];
const
Vec3
&
p
=
velocities
[
order
[
i
]];
const
Vec3
&
p
=
velocities
[
order
[
i
]];
vel
.
x
=
p
[
0
];
vel
[
3
*
i
]
=
(
float
)
p
[
0
];
vel
.
y
=
p
[
1
];
vel
[
3
*
i
+
1
]
=
(
float
)
p
[
1
];
vel
.
z
=
p
[
2
];
vel
[
3
*
i
+
2
]
=
(
float
)
p
[
2
];
}
}
f
or
(
int
i
=
numParticles
;
i
<
cc
.
getPaddedNumAtoms
();
i
++
)
f
loatBuffer
.
upload
(
vel
);
velm
[
i
]
=
mm_float4
(
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
);
copyFloatKernel
->
setArg
(
1
,
cc
.
getVelm
()
);
c
c
.
getVelm
().
upload
(
velm
);
c
opyFloatKernel
->
execute
(
numParticles
);
}
}
}
}
...
...
platforms/common/src/kernels/copyCoordinateBuffers.cc
0 → 100644
View file @
559da024
KERNEL
void
copyFloatBuffer
(
GLOBAL
float
*
RESTRICT
source
,
GLOBAL
float4
*
RESTRICT
dest
,
int
numAtoms
)
{
for
(
int
i
=
GLOBAL_ID
;
i
<
numAtoms
;
i
+=
GLOBAL_SIZE
)
{
dest
[
i
].
x
=
source
[
3
*
i
];
dest
[
i
].
y
=
source
[
3
*
i
+
1
];
dest
[
i
].
z
=
source
[
3
*
i
+
2
];
}
}
#ifdef SUPPORTS_DOUBLE_PRECISION
KERNEL
void
copyDoubleBuffer
(
GLOBAL
double
*
RESTRICT
source
,
GLOBAL
double4
*
RESTRICT
dest
,
int
numAtoms
)
{
for
(
int
i
=
GLOBAL_ID
;
i
<
numAtoms
;
i
+=
GLOBAL_SIZE
)
{
dest
[
i
].
x
=
source
[
3
*
i
];
dest
[
i
].
y
=
source
[
3
*
i
+
1
];
dest
[
i
].
z
=
source
[
3
*
i
+
2
];
}
}
#endif
\ No newline at end of file
platforms/cuda/src/CudaContext.cpp
View file @
559da024
...
@@ -359,6 +359,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
...
@@ -359,6 +359,7 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
nonbonded
=
new
CudaNonbondedUtilities
(
*
this
);
nonbonded
=
new
CudaNonbondedUtilities
(
*
this
);
integration
=
new
CudaIntegrationUtilities
(
*
this
,
system
);
integration
=
new
CudaIntegrationUtilities
(
*
this
,
system
);
expression
=
new
CudaExpressionUtilities
(
*
this
);
expression
=
new
CudaExpressionUtilities
(
*
this
);
clearBuffer
(
posq
);
}
}
CudaContext
::~
CudaContext
()
{
CudaContext
::~
CudaContext
()
{
...
...
platforms/hip/src/HipContext.cpp
View file @
559da024
...
@@ -351,6 +351,7 @@ HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSy
...
@@ -351,6 +351,7 @@ HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSy
nonbonded
=
new
HipNonbondedUtilities
(
*
this
);
nonbonded
=
new
HipNonbondedUtilities
(
*
this
);
integration
=
new
HipIntegrationUtilities
(
*
this
,
system
);
integration
=
new
HipIntegrationUtilities
(
*
this
,
system
);
expression
=
new
HipExpressionUtilities
(
*
this
);
expression
=
new
HipExpressionUtilities
(
*
this
);
clearBuffer
(
posq
);
}
}
HipContext
::~
HipContext
()
{
HipContext
::~
HipContext
()
{
...
...
platforms/opencl/src/OpenCLContext.cpp
View file @
559da024
...
@@ -492,6 +492,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
...
@@ -492,6 +492,7 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
nonbonded
=
new
OpenCLNonbondedUtilities
(
*
this
);
nonbonded
=
new
OpenCLNonbondedUtilities
(
*
this
);
integration
=
new
OpenCLIntegrationUtilities
(
*
this
,
system
);
integration
=
new
OpenCLIntegrationUtilities
(
*
this
,
system
);
expression
=
new
OpenCLExpressionUtilities
(
*
this
);
expression
=
new
OpenCLExpressionUtilities
(
*
this
);
clearBuffer
(
posq
);
}
}
OpenCLContext
::~
OpenCLContext
()
{
OpenCLContext
::~
OpenCLContext
()
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment