Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
46098e35
Commit
46098e35
authored
Jul 28, 2020
by
peastman
Browse files
States can save integrator parameters
parent
694c3930
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
94 additions
and
0 deletions
+94
-0
platforms/common/include/openmm/common/CommonKernels.h
platforms/common/include/openmm/common/CommonKernels.h
+16
-0
platforms/common/src/CommonKernels.cpp
platforms/common/src/CommonKernels.cpp
+52
-0
platforms/reference/include/ReferenceKernels.h
platforms/reference/include/ReferenceKernels.h
+16
-0
platforms/reference/src/ReferenceKernels.cpp
platforms/reference/src/ReferenceKernels.cpp
+10
-0
No files found.
platforms/common/include/openmm/common/CommonKernels.h
View file @
46098e35
...
@@ -1017,6 +1017,22 @@ public:
...
@@ -1017,6 +1017,22 @@ public:
* Load the chain states from a checkpoint.
* Load the chain states from a checkpoint.
*/
*/
void
loadCheckpoint
(
ContextImpl
&
context
,
std
::
istream
&
stream
);
void
loadCheckpoint
(
ContextImpl
&
context
,
std
::
istream
&
stream
);
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void
getChainStates
(
ContextImpl
&
context
,
std
::
vector
<
std
::
vector
<
double
>
>&
positions
,
std
::
vector
<
std
::
vector
<
double
>
>&
velocities
)
const
;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void
setChainStates
(
ContextImpl
&
context
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
positions
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
velocities
);
private:
private:
ComputeContext
&
cc
;
ComputeContext
&
cc
;
float
prevMaxPairDistance
;
float
prevMaxPairDistance
;
...
...
platforms/common/src/CommonKernels.cpp
View file @
46098e35
...
@@ -6255,6 +6255,58 @@ void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, i
...
@@ -6255,6 +6255,58 @@ void CommonIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context, i
}
}
}
}
void
CommonIntegrateNoseHooverStepKernel
::
getChainStates
(
ContextImpl
&
context
,
vector
<
vector
<
double
>
>&
positions
,
vector
<
vector
<
double
>
>&
velocities
)
const
{
int
numChains
=
chainState
.
size
();
bool
useDouble
=
cc
.
getUseDoublePrecision
()
||
cc
.
getUseMixedPrecision
();
positions
.
clear
();
velocities
.
clear
();
positions
.
resize
(
numChains
);
velocities
.
resize
(
numChains
);
for
(
int
i
=
0
;
i
<
numChains
;
i
++
)
{
const
ComputeArray
&
state
=
chainState
.
at
(
i
);
if
(
useDouble
)
{
vector
<
mm_double2
>
stateVec
;
state
.
download
(
stateVec
);
for
(
int
j
=
0
;
j
<
stateVec
.
size
();
j
++
)
{
positions
[
i
].
push_back
(
stateVec
[
i
].
x
);
velocities
[
i
].
push_back
(
stateVec
[
i
].
y
);
}
}
else
{
vector
<
mm_float2
>
stateVec
;
state
.
download
(
stateVec
);
for
(
int
j
=
0
;
j
<
stateVec
.
size
();
j
++
)
{
positions
[
i
].
push_back
((
float
)
stateVec
[
i
].
x
);
velocities
[
i
].
push_back
((
float
)
stateVec
[
i
].
y
);
}
}
}
}
void
CommonIntegrateNoseHooverStepKernel
::
setChainStates
(
ContextImpl
&
context
,
const
vector
<
vector
<
double
>
>&
positions
,
const
vector
<
vector
<
double
>
>&
velocities
)
{
int
numChains
=
chainState
.
size
();
bool
useDouble
=
cc
.
getUseDoublePrecision
()
||
cc
.
getUseMixedPrecision
();
if
(
positions
.
size
()
!=
numChains
||
velocities
.
size
()
!=
numChains
)
throw
OpenMMException
(
"setChainStates(): wrong number of chains"
);
for
(
int
i
=
0
;
i
<
numChains
;
i
++
)
{
ComputeArray
&
state
=
chainState
[
i
];
if
(
positions
[
i
].
size
()
!=
state
.
getSize
()
||
velocities
[
i
].
size
()
!=
state
.
getSize
())
throw
OpenMMException
(
"setChainStates(): wrong number of beads in chain"
);
if
(
useDouble
)
{
vector
<
mm_double2
>
stateVec
;
for
(
int
j
=
0
;
j
<
state
.
getSize
();
j
++
)
stateVec
.
push_back
(
mm_double2
(
positions
[
i
][
j
],
velocities
[
i
][
j
]));
state
.
upload
(
stateVec
);
}
else
{
vector
<
mm_float2
>
stateVec
;
for
(
int
j
=
0
;
j
<
state
.
getSize
();
j
++
)
stateVec
.
push_back
(
mm_float2
((
float
)
positions
[
i
][
j
],
(
float
)
velocities
[
i
][
j
]));
state
.
upload
(
stateVec
);
}
}
}
void
CommonIntegrateBrownianStepKernel
::
initialize
(
const
System
&
system
,
const
BrownianIntegrator
&
integrator
)
{
void
CommonIntegrateBrownianStepKernel
::
initialize
(
const
System
&
system
,
const
BrownianIntegrator
&
integrator
)
{
cc
.
initializeContexts
();
cc
.
initializeContexts
();
cc
.
setAsCurrent
();
cc
.
setAsCurrent
();
...
...
platforms/reference/include/ReferenceKernels.h
View file @
46098e35
...
@@ -1211,6 +1211,22 @@ public:
...
@@ -1211,6 +1211,22 @@ public:
* Load the chain states from a checkpoint.
* Load the chain states from a checkpoint.
*/
*/
void
loadCheckpoint
(
ContextImpl
&
context
,
std
::
istream
&
stream
);
void
loadCheckpoint
(
ContextImpl
&
context
,
std
::
istream
&
stream
);
/**
* Get the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void
getChainStates
(
ContextImpl
&
context
,
std
::
vector
<
std
::
vector
<
double
>
>&
positions
,
std
::
vector
<
std
::
vector
<
double
>
>&
velocities
)
const
;
/**
* Set the internal states of all chains.
*
* @param context the context for which to get the states
* @param positions element [i][j] contains the position of bead j for chain i
* @param velocities element [i][j] contains the velocity of bead j for chain i
*/
void
setChainStates
(
ContextImpl
&
context
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
positions
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
velocities
);
private:
private:
ReferencePlatform
::
PlatformData
&
data
;
ReferencePlatform
::
PlatformData
&
data
;
ReferenceNoseHooverChain
*
chainPropagator
;
ReferenceNoseHooverChain
*
chainPropagator
;
...
...
platforms/reference/src/ReferenceKernels.cpp
View file @
46098e35
...
@@ -2378,6 +2378,16 @@ void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context
...
@@ -2378,6 +2378,16 @@ void ReferenceIntegrateNoseHooverStepKernel::loadCheckpoint(ContextImpl& context
}
}
}
}
void
ReferenceIntegrateNoseHooverStepKernel
::
getChainStates
(
ContextImpl
&
context
,
vector
<
vector
<
double
>
>&
positions
,
vector
<
vector
<
double
>
>&
velocities
)
const
{
positions
=
chainPositions
;
velocities
=
chainVelocities
;
}
void
ReferenceIntegrateNoseHooverStepKernel
::
setChainStates
(
ContextImpl
&
context
,
const
vector
<
vector
<
double
>
>&
positions
,
const
vector
<
vector
<
double
>
>&
velocities
)
{
chainPositions
=
positions
;
chainVelocities
=
velocities
;
}
ReferenceIntegrateLangevinStepKernel
::~
ReferenceIntegrateLangevinStepKernel
()
{
ReferenceIntegrateLangevinStepKernel
::~
ReferenceIntegrateLangevinStepKernel
()
{
if
(
dynamics
)
if
(
dynamics
)
delete
dynamics
;
delete
dynamics
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment