"openmmapi/vscode:/vscode.git/clone" did not exist on "aa7bd1cf520fb75d7e8fbbf52609bd0dcdc91ed2"
Commit a8cd8456 authored by peastman's avatar peastman
Browse files

Merge pull request #691 from swails/fix-sum

Fix compatibility with `unit.Quantity` and numpy arrays by having array methods accept and pass through relevant arguments.
parents 071b6fa0 5b4abbb5
...@@ -139,7 +139,15 @@ class Quantity(object): ...@@ -139,7 +139,15 @@ class Quantity(object):
first_item = iter(value).next() first_item = iter(value).next()
# Avoid infinite recursion for string, because a one-character # Avoid infinite recursion for string, because a one-character
# string is its own first element # string is its own first element
if value == first_item: try:
isstr = bool(value == first_item)
except ValueError:
# For numpy, value == first_item returns a numpy
# array of booleans, which cannot be evaluated for
# truthiness (a ValueError is raised). So in this
# case, we don't have a string
isstr = False
if isstr:
unit = dimensionless unit = dimensionless
else: else:
unit = Quantity(first_item).unit unit = Quantity(first_item).unit
...@@ -455,18 +463,24 @@ class Quantity(object): ...@@ -455,18 +463,24 @@ class Quantity(object):
new_value *= math.sqrt(unit_factor) new_value *= math.sqrt(unit_factor)
return Quantity(value=new_value, unit=new_unit) return Quantity(value=new_value, unit=new_unit)
def sum(self): def sum(self, *args, **kwargs):
""" """
Computes the sum of a sequence, with the result having the same unit as Computes the sum of a sequence, with the result having the same unit as
the current sequence. the current sequence.
If the value is not iterable, it raises a TypeError (same behavior as if If the value is not iterable, it raises a TypeError (same behavior as if
you tried to iterate over, for instance, an integer). you tried to iterate over, for instance, an integer).
This function can take as arguments any arguments recognized by
`numpy.sum`. If arguments are passed to a non-numpy array, a TypeError
is raised
""" """
try: try:
# This will be much faster for numpy arrays # This will be much faster for numpy arrays
mysum = self._value.sum() mysum = self._value.sum(*args, **kwargs)
except AttributeError: except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.sum')
if len(self._value) == 0: if len(self._value) == 0:
mysum = 0 mysum = 0
else: else:
...@@ -475,31 +489,43 @@ class Quantity(object): ...@@ -475,31 +489,43 @@ class Quantity(object):
mysum += self._value[i] mysum += self._value[i]
return Quantity(mysum, self.unit) return Quantity(mysum, self.unit)
def mean(self): def mean(self, *args, **kwargs):
""" """
Computes the mean of a sequence, with the result having the same unit as Computes the mean of a sequence, with the result having the same unit as
the current sequence. the current sequence.
If the value is not iterable, it raises a TypeError If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.mean`. If arguments are passed to a non-numpy array, a TypeError
is raised
""" """
try: try:
# Faster for numpy arrays # Faster for numpy arrays
mean = self._value.mean() mean = self._value.mean(*args, **kwargs)
except AttributeError: except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.mean')
mean = self.sum() / len(self._value) mean = self.sum() / len(self._value)
return Quantity(mean, self.unit) return Quantity(mean, self.unit)
def std(self): def std(self, *args, **kwargs):
""" """
Computes the square root of the variance of a sequence, with the result Computes the square root of the variance of a sequence, with the result
having the same unit as the current sequence. having the same unit as the current sequence.
If the value is not iterable, it raises a TypeError If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.std`. If arguments are passed to a non-numpy array, a TypeError
is raised
""" """
try: try:
# Faster for numpy arrays # Faster for numpy arrays
std = self._value.std() std = self._value.std(*args, **kwargs)
except AttributeError: except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.std')
mean = self.mean() mean = self.mean()
for val in self._value: for val in self._value:
res = mean - val res = mean - val
...@@ -508,34 +534,57 @@ class Quantity(object): ...@@ -508,34 +534,57 @@ class Quantity(object):
std = math.sqrt(var) std = math.sqrt(var)
return Quantity(std, self.unit) return Quantity(std, self.unit)
def max(self): def max(self, *args, **kwargs):
""" """
Computes the maximum value of the sequence, with the result having the Computes the maximum value of the sequence, with the result having the
same unit as the current sequence. same unit as the current sequence.
If the value is not iterable, it raises a TypeError If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.max`. If arguments are passed to a non-numpy array, a TypeError
is raised
""" """
try: try:
# Faster for numpy arrays # Faster for numpy arrays
mymax = self._value.max() mymax = self._value.max(*args, **kwargs)
except AttributeError: except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.max')
mymax = max(self._value) mymax = max(self._value)
return Quantity(mymax, self.unit) return Quantity(mymax, self.unit)
def min(self): def min(self, *args, **kwargs):
""" """
Computes the minimum value of the sequence, with the result having the Computes the minimum value of the sequence, with the result having the
same unit as the current sequence. same unit as the current sequence.
If the value is not iterable, it raises a TypeError If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.min`. If arguments are passed to a non-numpy array, a TypeError
is raised
""" """
try: try:
# Faster for numpy arrays # Faster for numpy arrays
mymin = self._value.min() mymin = self._value.min(*args, **kwargs)
except AttributeError: except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.min')
mymin = min(self._value) mymin = min(self._value)
return Quantity(mymin, self.unit) return Quantity(mymin, self.unit)
def reshape(self, shape, order='C'):
"""
Same as numpy.ndarray.reshape, except the result is a Quantity with the
same units as the current object rather than a plain numpy.ndarray
"""
try:
return Quantity(self._value.reshape(shape, order=order), self.unit)
except AttributeError:
raise AttributeError('Only numpy array Quantity objects can be '
'reshaped')
def __abs__(self): def __abs__(self):
""" """
Return absolute value of a Quantity. Return absolute value of a Quantity.
......
...@@ -75,8 +75,27 @@ class TestNumpyCompatibility(unittest.TestCase): ...@@ -75,8 +75,27 @@ class TestNumpyCompatibility(unittest.TestCase):
f.addMap(10, energy) f.addMap(10, energy)
size, energy_out = f.getMapParameters(0) size, energy_out = f.getMapParameters(0)
assert size == 10 self.assertEqual(size, 10)
np.testing.assert_array_almost_equal(energy, np.asarray(energy_out)) np.testing.assert_array_almost_equal(energy, np.asarray(energy_out))
class TestNumpyUnits(unittest.TestCase):
def setUp(self):
self.data = unit.Quantity(np.arange(300), unit.nanometers)
def testNumpyAttributes(self):
d = self.data.reshape((100, 3))
self.assertTrue(unit.is_quantity(d) and d.unit is unit.nanometers)
self.assertTrue(unit.is_quantity(d.sum()))
self.assertTrue(unit.is_quantity(d.sum(axis=0)))
self.assertTrue(unit.is_quantity(d.std()))
self.assertTrue(unit.is_quantity(d.std(axis=0)))
self.assertTrue(unit.is_quantity(d.max()))
self.assertTrue(unit.is_quantity(d.max(axis=1)))
self.assertTrue(unit.is_quantity(d.min()))
self.assertTrue(unit.is_quantity(d.min(axis=0)))
self.assertTrue(unit.is_quantity(d.mean()))
self.assertTrue(unit.is_quantity(d.mean(axis=1)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment