Commit 4f858162 authored by Robert McGibbon's avatar Robert McGibbon
Browse files

Merge branch 'master' into checkpoint-reporter

parents b4a98238 d3d352c1
......@@ -39,6 +39,7 @@ from simtk.openmm.app.topology import Residue
from simtk.openmm.vec3 import Vec3
from simtk.openmm import System, Context, NonbondedForce, CustomNonbondedForce, HarmonicBondForce, HarmonicAngleForce, VerletIntegrator, LocalEnergyMinimizer
from simtk.unit import nanometer, molar, elementary_charge, amu, gram, liter, degree, sqrt, acos, is_quantity, dot, norm
import simtk.unit as unit
import element as elem
import os
import random
......@@ -843,7 +844,9 @@ class Modeller(object):
if atom.element is not None:
newIndex[i] = index
index += 1
newTemplate.atoms.append(ForceField._TemplateAtomData(atom.name, atom.type, atom.element))
newAtom = ForceField._TemplateAtomData(atom.name, atom.type, atom.element)
newAtom.externalBonds = atom.externalBonds
newTemplate.atoms.append(newAtom)
for b1, b2 in template.bonds:
if b1 in newIndex and b2 in newIndex:
newTemplate.bonds.append((newIndex[b1], newIndex[b2]))
......@@ -968,7 +971,7 @@ class Modeller(object):
# and hope that energy minimization will fix it.
knownPositions = [x for x in templateAtomPositions if x is not None]
position = sum(knownPositions)/len(knownPositions)
position = unit.sum(knownPositions)/len(knownPositions)
newPositions.append(position*nanometer)
for bond in self.topology.bonds():
if bond[0] in newAtoms and bond[1] in newAtoms:
......
......@@ -201,9 +201,12 @@ class StateDataReporter(object):
if self._density:
values.append((self._totalMass/volume).value_in_unit(unit.gram/unit.item/unit.milliliter))
if self._speed:
elapsedDays = (clockTime-self._initialClockTime)/86400
elapsedDays = (clockTime-self._initialClockTime)/86400.0
elapsedNs = (state.getTime()-self._initialSimulationTime).value_in_unit(unit.nanosecond)
if elapsedDays > 0.0:
values.append('%.3g' % (elapsedNs/elapsedDays))
else:
values.append('--')
if self._remainingTime:
elapsedSeconds = clockTime-self._initialClockTime
elapsedSteps = simulation.currentStep-self._initialSteps
......
......@@ -13,10 +13,10 @@ See https://simtk.org/home/pyopenmm for details"
%module (docstring=DOCSTRING) openmm
%include "typemaps.i"
%include "factory.i"
%include "std_string.i"
%include "std_iostream.i"
%include "typemaps.i"
%include "std_map.i"
%include "std_pair.i"
......
......@@ -156,3 +156,36 @@
$3 = &tempC;
}
%typemap(out) std::string OpenMM::Context::createCheckpoint{
// createCheckpoint returns a bytes object
$result = PyBytes_FromStringAndSize($1.c_str(), $1.length());
}
%typemap(in) std::string {
// if we have a C++ method that takes in a std::string, we're most happy to accept
// a python bytes object. But if the user passes in a unicode object we'll try
// to recover by encoding it to UTF-8 bytes
PyObject* temp = NULL;
char* c_str = NULL;
Py_ssize_t len = 0;
if (PyUnicode_Check($input)) {
temp = PyUnicode_AsUTF8String($input);
if (temp == NULL) {
SWIG_exception_fail(SWIG_TypeError, "'utf-8' codec can't decode byte");
}
PyBytes_AsStringAndSize(temp, &c_str, &len);
Py_XDECREF(temp);
} else if (PyBytes_Check($input)) {
PyBytes_AsStringAndSize($input, &c_str, &len);
} else {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
if (c_str == NULL) {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
$1 = std::string(c_str, len);
}
\ No newline at end of file
import unittest
import simtk.openmm as mm
class TestBytes(unittest.TestCase):
def test_createCheckpoint(self):
system = mm.System()
system.addParticle(1.0)
refPositions = [(0,0,0)]
context = mm.Context(system, mm.VerletIntegrator(0))
context.setPositions(refPositions)
chk = context.createCheckpoint()
# check that the return value of createCheckpoint is of type bytes (non-unicode)
assert isinstance(chk, bytes)
# set the positions to something random then reload the checkpoint, and
# make sure that the positions get restored correctly
context.setPositions([(12345, 12345, 123451)])
context.loadCheckpoint(chk)
newPositions = context.getState(getPositions=True).getPositions()._value
assert newPositions == refPositions
# try encoding the checkpoint in utf-8. OpenMM should be able to handle this too
context.setPositions([(12345, 12345, 123451)])
context.loadCheckpoint(chk.decode('utf-8'))
newPositions = context.getState(getPositions=True).getPositions()._value
assert newPositions == refPositions
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment