Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
bc05f1c0
Unverified
Commit
bc05f1c0
authored
Sep 02, 2025
by
Emilio Gallicchio
Committed by
GitHub
Sep 01, 2025
Browse files
merge setDisplacements kernel into copyState (#5058)
parent
c91894e8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
66 deletions
+39
-66
platforms/common/include/openmm/common/CommonKernels.h
platforms/common/include/openmm/common/CommonKernels.h
+0
-2
platforms/common/src/CommonKernels.cpp
platforms/common/src/CommonKernels.cpp
+4
-19
platforms/common/src/kernels/atmforce.cc
platforms/common/src/kernels/atmforce.cc
+35
-45
No files found.
platforms/common/include/openmm/common/CommonKernels.h
View file @
bc05f1c0
...
@@ -1397,14 +1397,12 @@ private:
...
@@ -1397,14 +1397,12 @@ private:
bool
hasInitializedKernel
;
bool
hasInitializedKernel
;
ComputeContext
&
cc
;
ComputeContext
&
cc
;
ComputeArray
displ1
,
displ0
;
// actual displacements used in calculation
ComputeArray
displacement1
,
displacement0
;
// fixed lab-frame displacements
ComputeArray
displacement1
,
displacement0
;
// fixed lab-frame displacements
ComputeArray
displParticles
;
// variable displacements based on atom positions
ComputeArray
displParticles
;
// variable displacements based on atom positions
// int4 arranged as (pDestination1, pOrigin1, pDestination0, pOrigin0
// int4 arranged as (pDestination1, pOrigin1, pDestination0, pOrigin0
ComputeArray
invAtomOrder
,
inner0InvAtomOrder
,
inner1InvAtomOrder
;
ComputeArray
invAtomOrder
,
inner0InvAtomOrder
,
inner1InvAtomOrder
;
ComputeArray
dforce0
,
dforce1
;
// forces due to variable displacements
ComputeArray
dforce0
,
dforce1
;
// forces due to variable displacements
ComputeKernel
copyStateKernel
;
ComputeKernel
copyStateKernel
;
ComputeKernel
setDisplacementsKernel
;
ComputeKernel
resetDisplForceKernel
;
ComputeKernel
resetDisplForceKernel
;
ComputeKernel
displForceKernel
;
ComputeKernel
displForceKernel
;
ComputeKernel
hybridForceKernel
;
ComputeKernel
hybridForceKernel
;
...
...
platforms/common/src/CommonKernels.cpp
View file @
bc05f1c0
...
@@ -4141,10 +4141,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce&
...
@@ -4141,10 +4141,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce&
displVector0
[
p
]
=
mm_double4
(
d0
[
p
][
0
],
d0
[
p
][
1
],
d0
[
p
][
2
],
0
);
displVector0
[
p
]
=
mm_double4
(
d0
[
p
][
0
],
d0
[
p
][
1
],
d0
[
p
][
2
],
0
);
displParticlesVector
[
p
]
=
mm_int4
(
j1
[
p
],
i1
[
p
],
j0
[
p
],
i0
[
p
]);
displParticlesVector
[
p
]
=
mm_int4
(
j1
[
p
],
i1
[
p
],
j0
[
p
],
i0
[
p
]);
}
}
displ1
.
initialize
<
mm_double4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displ1"
);
displacement1
.
initialize
<
mm_double4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement1"
);
displacement1
.
initialize
<
mm_double4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement1"
);
displacement1
.
upload
(
displVector1
);
displacement1
.
upload
(
displVector1
);
displ0
.
initialize
<
mm_double4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displ0"
);
displacement0
.
initialize
<
mm_double4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement0"
);
displacement0
.
initialize
<
mm_double4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement0"
);
displacement0
.
upload
(
displVector0
);
displacement0
.
upload
(
displVector0
);
}
}
...
@@ -4156,10 +4154,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce&
...
@@ -4156,10 +4154,8 @@ void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce&
displVector0
[
p
]
=
mm_float4
(
d0
[
p
][
0
],
d0
[
p
][
1
],
d0
[
p
][
2
],
0
);
displVector0
[
p
]
=
mm_float4
(
d0
[
p
][
0
],
d0
[
p
][
1
],
d0
[
p
][
2
],
0
);
displParticlesVector
[
p
]
=
mm_int4
(
j1
[
p
],
i1
[
p
],
j0
[
p
],
i0
[
p
]);
displParticlesVector
[
p
]
=
mm_int4
(
j1
[
p
],
i1
[
p
],
j0
[
p
],
i0
[
p
]);
}
}
displ1
.
initialize
<
mm_float4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displ1"
);
displacement1
.
initialize
<
mm_float4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement1"
);
displacement1
.
initialize
<
mm_float4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement1"
);
displacement1
.
upload
(
displVector1
);
displacement1
.
upload
(
displVector1
);
displ0
.
initialize
<
mm_float4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displ0"
);
displacement0
.
initialize
<
mm_float4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement0"
);
displacement0
.
initialize
<
mm_float4
>
(
cc
,
cc
.
getPaddedNumAtoms
(),
"displacement0"
);
displacement0
.
upload
(
displVector0
);
displacement0
.
upload
(
displVector0
);
}
}
...
@@ -4203,27 +4199,17 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
...
@@ -4203,27 +4199,17 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
ComputeProgram
program
=
cc
.
compileProgram
(
CommonKernelSources
::
atmforce
);
ComputeProgram
program
=
cc
.
compileProgram
(
CommonKernelSources
::
atmforce
);
//create the setDisplacements kernel
setDisplacementsKernel
=
program
->
createKernel
(
"setDisplacements"
);
setDisplacementsKernel
->
addArg
(
numParticles
);
setDisplacementsKernel
->
addArg
(
cc
.
getPosq
());
setDisplacementsKernel
->
addArg
(
displacement0
);
setDisplacementsKernel
->
addArg
(
displacement1
);
setDisplacementsKernel
->
addArg
(
displParticles
);
setDisplacementsKernel
->
addArg
(
cc
.
getAtomIndexArray
());
setDisplacementsKernel
->
addArg
(
invAtomOrder
);
setDisplacementsKernel
->
addArg
(
displ0
);
setDisplacementsKernel
->
addArg
(
displ1
);
//create CopyState kernel
//create CopyState kernel
copyStateKernel
=
program
->
createKernel
(
"copyState"
);
copyStateKernel
=
program
->
createKernel
(
"copyState"
);
copyStateKernel
->
addArg
(
numParticles
);
copyStateKernel
->
addArg
(
numParticles
);
copyStateKernel
->
addArg
(
cc
.
getPosq
());
copyStateKernel
->
addArg
(
cc
.
getPosq
());
copyStateKernel
->
addArg
(
cc0
.
getPosq
());
copyStateKernel
->
addArg
(
cc0
.
getPosq
());
copyStateKernel
->
addArg
(
cc1
.
getPosq
());
copyStateKernel
->
addArg
(
cc1
.
getPosq
());
copyStateKernel
->
addArg
(
displ0
);
copyStateKernel
->
addArg
(
displacement0
);
copyStateKernel
->
addArg
(
displ1
);
copyStateKernel
->
addArg
(
displacement1
);
copyStateKernel
->
addArg
(
displParticles
);
copyStateKernel
->
addArg
(
cc
.
getAtomIndexArray
());
copyStateKernel
->
addArg
(
cc
.
getAtomIndexArray
());
copyStateKernel
->
addArg
(
invAtomOrder
);
copyStateKernel
->
addArg
(
inner0InvAtomOrder
);
copyStateKernel
->
addArg
(
inner0InvAtomOrder
);
copyStateKernel
->
addArg
(
inner1InvAtomOrder
);
copyStateKernel
->
addArg
(
inner1InvAtomOrder
);
if
(
cc
.
getUseMixedPrecision
())
{
if
(
cc
.
getUseMixedPrecision
())
{
...
@@ -4313,7 +4299,6 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context,
...
@@ -4313,7 +4299,6 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context,
cc0
.
reorderAtoms
();
cc0
.
reorderAtoms
();
cc1
.
reorderAtoms
();
cc1
.
reorderAtoms
();
setDisplacementsKernel
->
execute
(
numParticles
);
copyStateKernel
->
execute
(
numParticles
);
copyStateKernel
->
execute
(
numParticles
);
map
<
string
,
double
>
innerParameters0
=
innerContext0
.
getParameters
();
map
<
string
,
double
>
innerParameters0
=
innerContext0
.
getParameters
();
...
...
platforms/common/src/kernels/atmforce.cc
View file @
bc05f1c0
...
@@ -26,47 +26,6 @@ KERNEL void hybridForce(int numParticles,
...
@@ -26,47 +26,6 @@ KERNEL void hybridForce(int numParticles,
}
}
}
}
KERNEL
void
setDisplacements
(
int
numParticles
,
GLOBAL
real4
*
RESTRICT
posq
,
GLOBAL
real4
*
RESTRICT
displacement0
,
GLOBAL
real4
*
RESTRICT
displacement1
,
GLOBAL
int4
*
displParticles
,
GLOBAL
int
*
RESTRICT
atomOrder
,
GLOBAL
int
*
RESTRICT
invAtomOrder
,
GLOBAL
real4
*
RESTRICT
displ0
,
GLOBAL
real4
*
RESTRICT
displ1
)
{
for
(
int
index
=
GLOBAL_ID
;
index
<
numParticles
;
index
+=
GLOBAL_SIZE
)
{
int
atom
=
atomOrder
[
index
];
int
pj1
=
displParticles
[
atom
].
x
;
int
pi1
=
displParticles
[
atom
].
y
;
int
pj0
=
displParticles
[
atom
].
z
;
int
pi0
=
displParticles
[
atom
].
w
;
if
(
pj1
>=
0
&&
pi1
>=
0
)
{
// variable system coordinate displacements
int
indexj1
=
invAtomOrder
[
pj1
];
int
indexi1
=
invAtomOrder
[
pi1
];
displ1
[
atom
]
=
make_real4
((
real
)
posq
[
indexj1
].
x
-
posq
[
indexi1
].
x
,
(
real
)
posq
[
indexj1
].
y
-
posq
[
indexi1
].
y
,
(
real
)
posq
[
indexj1
].
z
-
posq
[
indexi1
].
z
,
(
real
)
0
);
if
(
pj0
>=
0
&&
pi0
>=
0
)
{
int
indexj0
=
invAtomOrder
[
pj0
];
int
indexi0
=
invAtomOrder
[
pi0
];
displ0
[
atom
]
=
make_real4
((
real
)
posq
[
indexj0
].
x
-
posq
[
indexi0
].
x
,
(
real
)
posq
[
indexj0
].
y
-
posq
[
indexi0
].
y
,
(
real
)
posq
[
indexj0
].
z
-
posq
[
indexi0
].
z
,
(
real
)
0
);
}
else
{
displ0
[
atom
]
=
make_real4
((
real
)
0
,
(
real
)
0
,
(
real
)
0
,
(
real
)
0
);
}
}
else
{
//fixed lab frame displacement
displ1
[
atom
]
=
displacement1
[
atom
];
displ0
[
atom
]
=
displacement0
[
atom
];
}
}
}
//reset variable displacement forces
//reset variable displacement forces
KERNEL
void
resetDisplForce
(
int
numParticles
,
KERNEL
void
resetDisplForce
(
int
numParticles
,
int
paddedNumParticles
,
int
paddedNumParticles
,
...
@@ -134,9 +93,11 @@ KERNEL void copyState(int numParticles,
...
@@ -134,9 +93,11 @@ KERNEL void copyState(int numParticles,
GLOBAL
real4
*
RESTRICT
posq
,
GLOBAL
real4
*
RESTRICT
posq
,
GLOBAL
real4
*
RESTRICT
posq0
,
GLOBAL
real4
*
RESTRICT
posq0
,
GLOBAL
real4
*
RESTRICT
posq1
,
GLOBAL
real4
*
RESTRICT
posq1
,
GLOBAL
real4
*
RESTRICT
displ0
,
GLOBAL
real4
*
RESTRICT
displacement0
,
GLOBAL
real4
*
RESTRICT
displ1
,
GLOBAL
real4
*
RESTRICT
displacement1
,
GLOBAL
int4
*
displParticles
,
GLOBAL
int
*
RESTRICT
atomOrder
,
GLOBAL
int
*
RESTRICT
atomOrder
,
GLOBAL
int
*
RESTRICT
invAtomOrder
,
GLOBAL
int
*
RESTRICT
inner0InvAtomOrder
,
GLOBAL
int
*
RESTRICT
inner0InvAtomOrder
,
GLOBAL
int
*
RESTRICT
inner1InvAtomOrder
GLOBAL
int
*
RESTRICT
inner1InvAtomOrder
#ifdef USE_MIXED_PRECISION
#ifdef USE_MIXED_PRECISION
...
@@ -146,12 +107,41 @@ KERNEL void copyState(int numParticles,
...
@@ -146,12 +107,41 @@ KERNEL void copyState(int numParticles,
GLOBAL
real4
*
RESTRICT
posq1Correction
GLOBAL
real4
*
RESTRICT
posq1Correction
#endif
#endif
)
{
)
{
for
(
int
i
=
GLOBAL_ID
;
i
<
numParticles
;
i
+=
GLOBAL_SIZE
)
{
for
(
int
i
=
GLOBAL_ID
;
i
<
numParticles
;
i
+=
GLOBAL_SIZE
)
{
int
atom
=
atomOrder
[
i
];
int
atom
=
atomOrder
[
i
];
//default fixed lab frame displacement
real4
displ0
=
displacement0
[
atom
];
real4
displ1
=
displacement1
[
atom
];
//override with variable displacements if set
int
pj1
=
displParticles
[
atom
].
x
;
int
pi1
=
displParticles
[
atom
].
y
;
int
pj0
=
displParticles
[
atom
].
z
;
int
pi0
=
displParticles
[
atom
].
w
;
if
(
pj1
>=
0
&&
pi1
>=
0
)
{
// variable system coordinate displacements
int
indexj1
=
invAtomOrder
[
pj1
];
int
indexi1
=
invAtomOrder
[
pi1
];
displ1
=
make_real4
((
real
)
posq
[
indexj1
].
x
-
posq
[
indexi1
].
x
,
(
real
)
posq
[
indexj1
].
y
-
posq
[
indexi1
].
y
,
(
real
)
posq
[
indexj1
].
z
-
posq
[
indexi1
].
z
,
(
real
)
0
);
if
(
pj0
>=
0
&&
pi0
>=
0
)
{
int
indexj0
=
invAtomOrder
[
pj0
];
int
indexi0
=
invAtomOrder
[
pi0
];
displ0
=
make_real4
((
real
)
posq
[
indexj0
].
x
-
posq
[
indexi0
].
x
,
(
real
)
posq
[
indexj0
].
y
-
posq
[
indexi0
].
y
,
(
real
)
posq
[
indexj0
].
z
-
posq
[
indexi0
].
z
,
(
real
)
0
);
}
else
{
displ0
=
make_real4
((
real
)
0
,
(
real
)
0
,
(
real
)
0
,
(
real
)
0
);
}
}
int
index0
=
inner0InvAtomOrder
[
atom
];
int
index0
=
inner0InvAtomOrder
[
atom
];
int
index1
=
inner1InvAtomOrder
[
atom
];
int
index1
=
inner1InvAtomOrder
[
atom
];
real4
p0
=
posq
[
i
]
+
make_real4
((
real
)
displ0
[
atom
]
.
x
,
(
real
)
displ0
[
atom
]
.
y
,
(
real
)
displ0
[
atom
]
.
z
,
0
);
real4
p0
=
posq
[
i
]
+
make_real4
((
real
)
displ0
.
x
,
(
real
)
displ0
.
y
,
(
real
)
displ0
.
z
,
0
);
real4
p1
=
posq
[
i
]
+
make_real4
((
real
)
displ1
[
atom
]
.
x
,
(
real
)
displ1
[
atom
]
.
y
,
(
real
)
displ1
[
atom
]
.
z
,
0
);
real4
p1
=
posq
[
i
]
+
make_real4
((
real
)
displ1
.
x
,
(
real
)
displ1
.
y
,
(
real
)
displ1
.
z
,
0
);
p0
.
w
=
posq0
[
i
].
w
;
p0
.
w
=
posq0
[
i
].
w
;
p1
.
w
=
posq1
[
i
].
w
;
p1
.
w
=
posq1
[
i
].
w
;
posq0
[
index0
]
=
p0
;
posq0
[
index0
]
=
p0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment