"docs-source/vscode:/vscode.git/clone" did not exist on "bc6fe72928817f33edfbec3150e40bcfc636c2c7"
Commit 40b969f2 authored by Robert McGibbon's avatar Robert McGibbon
Browse files

create/load both working

parent 67878bab
...@@ -13,10 +13,10 @@ See https://simtk.org/home/pyopenmm for details" ...@@ -13,10 +13,10 @@ See https://simtk.org/home/pyopenmm for details"
%module (docstring=DOCSTRING) openmm %module (docstring=DOCSTRING) openmm
%include "typemaps.i"
%include "factory.i" %include "factory.i"
%include "std_string.i" %include "std_string.i"
%include "std_iostream.i" %include "std_iostream.i"
%include "typemaps.i"
%include "std_map.i" %include "std_map.i"
%include "std_pair.i" %include "std_pair.i"
......
...@@ -157,5 +157,35 @@ ...@@ -157,5 +157,35 @@
} }
%typemap(out) std::string OpenMM::Context::createCheckpoint{ %typemap(out) std::string OpenMM::Context::createCheckpoint{
$result = PyBytes_FromString($1.c_str()); // createCheckpoint returns a bytes object
$result = PyBytes_FromStringAndSize($1.c_str(), $1.length());
}
%typemap(in) std::string {
// if we have a C++ method that takes in a std::string, we're most happy to accept
// a python bytes object. But if the user passes in a unicode object we'll try
// to recover by encoding it to UTF-8 bytes
PyObject* temp = NULL;
char* c_str = NULL;
Py_ssize_t len = 0;
if (PyUnicode_Check($input)) {
temp = PyUnicode_AsUTF8String($input);
if (temp == NULL) {
SWIG_exception_fail(SWIG_TypeError, "'utf-8' codec can't decode byte");
}
PyBytes_AsStringAndSize(temp, &c_str, &len);
Py_XDECREF(temp);
} else if (PyBytes_Check($input)) {
PyBytes_AsStringAndSize($input, &c_str, &len);
} else {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
if (c_str == NULL) {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
$1 = std::string(c_str, len);
} }
\ No newline at end of file
...@@ -4,12 +4,33 @@ import simtk.openmm as mm ...@@ -4,12 +4,33 @@ import simtk.openmm as mm
class TestBytes(unittest.TestCase): class TestBytes(unittest.TestCase):
def test_createCheckpoint(self): def test_createCheckpoint(self):
# check that the return value of createCheckpoint is of type bytes (non-unicode)
system = mm.System() system = mm.System()
system.addParticle(1.0) system.addParticle(1.0)
mm.Context(system, mm.VerletIntegrator(0)) refPositions = [(0,0,0)]
chk = mm.Context(system, mm.VerletIntegrator(0)).createCheckpoint()
context = mm.Context(system, mm.VerletIntegrator(0))
context.setPositions(refPositions)
chk = context.createCheckpoint()
# check that the return value of createCheckpoint is of type bytes (non-unicode)
assert isinstance(chk, bytes) assert isinstance(chk, bytes)
# set the positions to something random then reload the checkpoint, and
# make sure that the positions get restored correctly
context.setPositions([(12345, 12345, 123451)])
context.loadCheckpoint(chk)
newPositions = context.getState(getPositions=True).getPositions()._value
assert newPositions == refPositions
# try encoding the checkpoint in utf-8. OpenMM should be able to handle this too
context.setPositions([(12345, 12345, 123451)])
context.loadCheckpoint(chk.decode('utf-8'))
newPositions = context.getState(getPositions=True).getPositions()._value
assert newPositions == refPositions
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment