Commit a5c00fa5 authored by Jason Swails's avatar Jason Swails
Browse files

Fix another small bug for unit handling with numpy arrays. Fixes the exception:

Traceback (most recent call last):
  File ".../python/tests/TestNumpyCompatibility.py", line 87, in testNumpyAttributes
    d = self.data.reshape((100, 3))
  File ".../simtk/unit/quantity.py", line 575, in reshape
    return Quantity(self._value.reshape(shape, order=order))
  File ".../simtk/unit/quantity.py", line 142, in __init__
    if value == first_item:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Also add a test for new unit numpy capabilities.
parent 019f5026
...@@ -139,7 +139,15 @@ class Quantity(object): ...@@ -139,7 +139,15 @@ class Quantity(object):
first_item = iter(value).next() first_item = iter(value).next()
# Avoid infinite recursion for string, because a one-character # Avoid infinite recursion for string, because a one-character
# string is its own first element # string is its own first element
if value == first_item: try:
isstr = bool(value == first_item)
except ValueError:
# For numpy, value == first_item returns a numpy
# array of booleans, which cannot be evaluated for
# truthiness (a ValueError is raised). So in this
# case, we don't have a string
isstr = False
if isstr:
unit = dimensionless unit = dimensionless
else: else:
unit = Quantity(first_item).unit unit = Quantity(first_item).unit
......
...@@ -78,5 +78,24 @@ class TestNumpyCompatibility(unittest.TestCase): ...@@ -78,5 +78,24 @@ class TestNumpyCompatibility(unittest.TestCase):
assert size == 10 assert size == 10
np.testing.assert_array_almost_equal(energy, np.asarray(energy_out)) np.testing.assert_array_almost_equal(energy, np.asarray(energy_out))
class TestNumpyUnits(unittest.TestCase):
def setUp(self):
self.data = unit.Quantity(np.arange(300), unit.nanometers)
def testNumpyAttributes(self):
d = self.data.reshape((100, 3))
self.assertTrue(unit.is_quantity(d))
self.assertTrue(unit.is_quantity(d.sum()))
self.assertTrue(unit.is_quantity(d.sum(axis=0)))
self.assertTrue(unit.is_quantity(d.std()))
self.assertTrue(unit.is_quantity(d.std(axis=0)))
self.assertTrue(unit.is_quantity(d.max()))
self.assertTrue(unit.is_quantity(d.max(axis=1)))
self.assertTrue(unit.is_quantity(d.min()))
self.assertTrue(unit.is_quantity(d.min(axis=0)))
self.assertTrue(unit.is_quantity(d.mean()))
self.assertTrue(unit.is_quantity(d.mean(axis=1)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment