Commit c1e5ec2f authored by peastman's avatar peastman
Browse files

Add ability to enforce units on input arguments

parent a9ea2b9b
......@@ -446,6 +446,7 @@ class SwigInputBuilder:
#write only non Constructor and Destructor methods and python mods
self.fOut.write("\n")
methodsWithOutputArgs = set()
for items in methodList:
clearOutput=""
(shortClassName, memberNode,
......@@ -479,6 +480,7 @@ class SwigInputBuilder:
(INDENT, simpleType, pType, pName))
clearOutput = "%s%s%%clear %s %s;\n" \
% (clearOutput, INDENT, pType, pName)
methodsWithOutputArgs.add((shortClassName, methName))
mArgsstring = getText("argsstring", memberNode)
try:
......@@ -503,7 +505,10 @@ class SwigInputBuilder:
paramList = findNodes(memberNode, 'param')
# write pythonprepend blocks
mArgsstring = getText("argsstring", memberNode)
if isConstructors:
mArgsstring = '' # specifying args to constructors seems to prevent append and prepend from working
else:
mArgsstring = getText("argsstring", memberNode)
if self.fOutPythonprepend and \
len(paramList) and \
mArgsstring.find('=0') < 0:
......@@ -523,6 +528,28 @@ class SwigInputBuilder:
s = ("the %s object does not own its corresponding OpenMM object"
% self.__class__.__name__)
raise Exception(s)'''.format(argName=argName)
# Convert input arguments to the proper units, if specified.
if key not in methodsWithOutputArgs:
if key in self.configModule.UNITS:
argUnits=self.configModule.UNITS[key][1]
elif ("*", methName) in self.configModule.UNITS:
argUnits=self.configModule.UNITS[("*", methName)][1]
else:
argUnits = ()
if len(argUnits) > 0 and (self.SWIG_COMPACT_ARGUMENTS or isConstructors):
textInside += '''
args = list(args)'''
for i, units in enumerate(argUnits):
if units is not None:
if self.SWIG_COMPACT_ARGUMENTS or isConstructors:
argName = 'args[%s]' % i
else:
argName = getText('declname', paramList[i])
textInside += '''
if unit.is_quantity({argName}):
{argName} = {argName}.value_in_unit({units})'''.format(argName=argName, units=units)
for argNum in self.configModule.REQUIRE_ORDERED_SET.get(key, []):
if self.SWIG_COMPACT_ARGUMENTS:
argName = 'args[%s]' % argNum
......@@ -568,7 +595,7 @@ class SwigInputBuilder:
% (addText, INDENT, valueUnits[0])
for vUnit in valueUnits[1]:
if vUnit is not None:
if vUnit is not None and key in methodsWithOutputArgs:
addText = "%s%sval[%s]=unit.Quantity(val[%s], %s)\n" \
% (addText, INDENT, index, index, vUnit)
index+=1
......
......@@ -117,9 +117,6 @@ SKIP_METHODS = [('State', 'getPositions'),
# The build script assumes method args that are non-const references are
# used to output values. This list gives excpetions to this rule.
NO_OUTPUT_ARGS = [('LocalEnergyMinimizer', 'minimize', 'context'),
('System', 'getDefaultPeriodicBoxVectors', 'a'),
('System', 'getDefaultPeriodicBoxVectors', 'b'),
('System', 'getDefaultPeriodicBoxVectors', 'c'),
('Platform', 'setPropertyValue', 'context'),
('AmoebaTorsionTorsionForce', 'setTorsionTorsionGrid', 'grid'),
('AmoebaVdwForce', 'setParticleExclusions', 'exclusions'),
......@@ -175,8 +172,14 @@ UNITS = {
("*", "getDefaultPressureX") : ("unit.bar", ()),
("*", "getDefaultPressureY") : ("unit.bar", ()),
("*", "getDefaultPressureZ") : ("unit.bar", ()),
("*", "setDefaultPressure") : (None, ("unit.bar",)),
("*", "setDefaultPressureX") : (None, ("unit.bar",)),
("*", "setDefaultPressureY") : (None, ("unit.bar",)),
("*", "setDefaultPressureZ") : (None, ("unit.bar",)),
("*", "getDefaultSurfaceTension") : ("unit.bar*unit.nanometer", ()),
("*", "setDefaultSurfaceTension") : (None, ("unit.bar*unit.nanometer",)),
("*", "getDefaultTemperature") : ("unit.kelvin", ()),
("*", "setDefaultTemperature") : (None, ("unit.kelvin",)),
("*", "getErrorTolerance") : (None, ()),
("*", "getEwaldErrorTolerance") : (None, ()),
("*", "getFriction") : ("1/unit.picosecond", ()),
......@@ -459,6 +462,7 @@ UNITS = {
("System", "getVirtualSite") : (None, ()),
("DrudeLangevinIntegrator", "getDrudeTemperature") : ("unit.kelvin", ()),
("DrudeLangevinIntegrator", "getMaxDrudeDistance") : ("unit.nanometer", ()),
("MonteCarloMembraneBarostat", "MonteCarloMembraneBarostat") : (None, ("unit.bar", "unit.bar*unit.nanometer", "unit.kelvin", None, None, None)),
("MonteCarloMembraneBarostat", "getXYMode") : (None, ()),
("MonteCarloMembraneBarostat", "getZMode") : (None, ()),
("DrudeLangevinIntegrator", "getDrudeFriction") : ("1/unit.picosecond", ()),
......
......@@ -1176,6 +1176,28 @@ class TestAPIUnits(unittest.TestCase):
self.assertEqual(force.getDefaultTemperature(), 298.15*kelvin)
self.assertAlmostEqualUnit(force.getDefaultCollisionFrequency(), 1/picosecond)
def testMonteCarloMembraneBarostat(self):
""" Tests the MonteCarloMembraneBarostat API features """
force = MonteCarloMembraneBarostat(1.0, 1.5, 300, MonteCarloMembraneBarostat.XYAnisotropic, MonteCarloMembraneBarostat.ZFixed, 25)
self.assertEqual(force.getDefaultPressure(), 1.0*bar)
self.assertEqual(force.getDefaultSurfaceTension(), 1.5*bar*nanometer)
self.assertEqual(force.getDefaultTemperature(), 300*kelvin)
self.assertEqual(force.getXYMode(), MonteCarloMembraneBarostat.XYAnisotropic)
self.assertEqual(force.getZMode(), MonteCarloMembraneBarostat.ZFixed)
self.assertEqual(force.getFrequency(), 25)
force = MonteCarloMembraneBarostat(1.1*bar, 2.0*bar*nanometer, 350*kelvin, MonteCarloMembraneBarostat.XYAnisotropic, MonteCarloMembraneBarostat.ZFixed, 25)
self.assertEqual(force.getDefaultPressure(), 1.1*bar)
self.assertEqual(force.getDefaultSurfaceTension(), 2.0*bar*nanometer)
self.assertEqual(force.getDefaultTemperature(), 350*kelvin)
force.setDefaultPressure(1.2*bar)
force.setDefaultSurfaceTension(2.5*bar*nanometer)
force.setDefaultTemperature(298.15)
self.assertEqual(force.getDefaultPressure(), 1.2*bar)
self.assertEqual(force.getDefaultSurfaceTension(), 2.5*bar*nanometer)
self.assertEqual(force.getDefaultTemperature(), 298.15*kelvin)
def testDrudeSCFIntegrator(self):
""" Tests the DrudeSCFIntegrator API features """
integrator = DrudeSCFIntegrator(0.002)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment