Commit 306fe04e authored by peastman's avatar peastman
Browse files

Merge pull request #378 from rmcgibbo/bytes

loadCheckpoint/createCheckpoint should use bytes
parents fe129ae2 40b969f2
...@@ -13,10 +13,10 @@ See https://simtk.org/home/pyopenmm for details" ...@@ -13,10 +13,10 @@ See https://simtk.org/home/pyopenmm for details"
%module (docstring=DOCSTRING) openmm %module (docstring=DOCSTRING) openmm
%include "typemaps.i"
%include "factory.i" %include "factory.i"
%include "std_string.i" %include "std_string.i"
%include "std_iostream.i" %include "std_iostream.i"
%include "typemaps.i"
%include "std_map.i" %include "std_map.i"
%include "std_pair.i" %include "std_pair.i"
......
...@@ -156,3 +156,36 @@ ...@@ -156,3 +156,36 @@
$3 = &tempC; $3 = &tempC;
} }
%typemap(out) std::string OpenMM::Context::createCheckpoint{
// createCheckpoint returns a bytes object
$result = PyBytes_FromStringAndSize($1.c_str(), $1.length());
}
%typemap(in) std::string {
// if we have a C++ method that takes in a std::string, we're most happy to accept
// a python bytes object. But if the user passes in a unicode object we'll try
// to recover by encoding it to UTF-8 bytes
PyObject* temp = NULL;
char* c_str = NULL;
Py_ssize_t len = 0;
if (PyUnicode_Check($input)) {
temp = PyUnicode_AsUTF8String($input);
if (temp == NULL) {
SWIG_exception_fail(SWIG_TypeError, "'utf-8' codec can't decode byte");
}
PyBytes_AsStringAndSize(temp, &c_str, &len);
Py_XDECREF(temp);
} else if (PyBytes_Check($input)) {
PyBytes_AsStringAndSize($input, &c_str, &len);
} else {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
if (c_str == NULL) {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
$1 = std::string(c_str, len);
}
\ No newline at end of file
import unittest
import simtk.openmm as mm
class TestBytes(unittest.TestCase):
def test_createCheckpoint(self):
system = mm.System()
system.addParticle(1.0)
refPositions = [(0,0,0)]
context = mm.Context(system, mm.VerletIntegrator(0))
context.setPositions(refPositions)
chk = context.createCheckpoint()
# check that the return value of createCheckpoint is of type bytes (non-unicode)
assert isinstance(chk, bytes)
# set the positions to something random then reload the checkpoint, and
# make sure that the positions get restored correctly
context.setPositions([(12345, 12345, 123451)])
context.loadCheckpoint(chk)
newPositions = context.getState(getPositions=True).getPositions()._value
assert newPositions == refPositions
# try encoding the checkpoint in utf-8. OpenMM should be able to handle this too
context.setPositions([(12345, 12345, 123451)])
context.loadCheckpoint(chk.decode('utf-8'))
newPositions = context.getState(getPositions=True).getPositions()._value
assert newPositions == refPositions
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment