Commit a8cd8456 authored by peastman's avatar peastman
Browse files

Merge pull request #691 from swails/fix-sum

Fix compatibility with `unit.Quantity` and numpy arrays by having array methods accept and pass through relevant arguments.
parents 071b6fa0 5b4abbb5
......@@ -139,7 +139,15 @@ class Quantity(object):
first_item = iter(value).next()
# Avoid infinite recursion for string, because a one-character
# string is its own first element
if value == first_item:
try:
isstr = bool(value == first_item)
except ValueError:
# For numpy, value == first_item returns a numpy
# array of booleans, which cannot be evaluated for
# truthiness (a ValueError is raised). So in this
# case, we don't have a string
isstr = False
if isstr:
unit = dimensionless
else:
unit = Quantity(first_item).unit
......@@ -455,18 +463,24 @@ class Quantity(object):
new_value *= math.sqrt(unit_factor)
return Quantity(value=new_value, unit=new_unit)
def sum(self):
def sum(self, *args, **kwargs):
"""
Computes the sum of a sequence, with the result having the same unit as
the current sequence.
If the value is not iterable, it raises a TypeError (same behavior as if
you tried to iterate over, for instance, an integer).
This function can take as arguments any arguments recognized by
`numpy.sum`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# This will be much faster for numpy arrays
mysum = self._value.sum()
mysum = self._value.sum(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.sum')
if len(self._value) == 0:
mysum = 0
else:
......@@ -475,31 +489,43 @@ class Quantity(object):
mysum += self._value[i]
return Quantity(mysum, self.unit)
def mean(self):
def mean(self, *args, **kwargs):
"""
Computes the mean of a sequence, with the result having the same unit as
the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.mean`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
mean = self._value.mean()
mean = self._value.mean(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.mean')
mean = self.sum() / len(self._value)
return Quantity(mean, self.unit)
def std(self):
def std(self, *args, **kwargs):
"""
Computes the square root of the variance of a sequence, with the result
having the same unit as the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.std`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
std = self._value.std()
std = self._value.std(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.std')
mean = self.mean()
for val in self._value:
res = mean - val
......@@ -508,34 +534,57 @@ class Quantity(object):
std = math.sqrt(var)
return Quantity(std, self.unit)
def max(self):
def max(self, *args, **kwargs):
"""
Computes the maximum value of the sequence, with the result having the
same unit as the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.max`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
mymax = self._value.max()
mymax = self._value.max(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.max')
mymax = max(self._value)
return Quantity(mymax, self.unit)
def min(self):
def min(self, *args, **kwargs):
"""
Computes the minimum value of the sequence, with the result having the
same unit as the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.min`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
mymin = self._value.min()
mymin = self._value.min(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.min')
mymin = min(self._value)
return Quantity(mymin, self.unit)
def reshape(self, shape, order='C'):
"""
Same as numpy.ndarray.reshape, except the result is a Quantity with the
same units as the current object rather than a plain numpy.ndarray
"""
try:
return Quantity(self._value.reshape(shape, order=order), self.unit)
except AttributeError:
raise AttributeError('Only numpy array Quantity objects can be '
'reshaped')
def __abs__(self):
"""
Return absolute value of a Quantity.
......
......@@ -75,8 +75,27 @@ class TestNumpyCompatibility(unittest.TestCase):
f.addMap(10, energy)
size, energy_out = f.getMapParameters(0)
assert size == 10
self.assertEqual(size, 10)
np.testing.assert_array_almost_equal(energy, np.asarray(energy_out))
class TestNumpyUnits(unittest.TestCase):
def setUp(self):
self.data = unit.Quantity(np.arange(300), unit.nanometers)
def testNumpyAttributes(self):
d = self.data.reshape((100, 3))
self.assertTrue(unit.is_quantity(d) and d.unit is unit.nanometers)
self.assertTrue(unit.is_quantity(d.sum()))
self.assertTrue(unit.is_quantity(d.sum(axis=0)))
self.assertTrue(unit.is_quantity(d.std()))
self.assertTrue(unit.is_quantity(d.std(axis=0)))
self.assertTrue(unit.is_quantity(d.max()))
self.assertTrue(unit.is_quantity(d.max(axis=1)))
self.assertTrue(unit.is_quantity(d.min()))
self.assertTrue(unit.is_quantity(d.min(axis=0)))
self.assertTrue(unit.is_quantity(d.mean()))
self.assertTrue(unit.is_quantity(d.mean(axis=1)))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment