Commit dc8b4038 authored by Jason Swails's avatar Jason Swails
Browse files

Allow Quantity.sum (and other attributes, like mean, max, min, and std) to take

arguments and pass them to the numpy function if applicable. That way, users get
the full flexibility of the numpy API on those particular methods AND get the
added benefit that the result has the correct units (and it all runs fast).
parent 105b196e
......@@ -455,18 +455,24 @@ class Quantity(object):
new_value *= math.sqrt(unit_factor)
return Quantity(value=new_value, unit=new_unit)
def sum(self):
def sum(self, *args, **kwargs):
"""
Computes the sum of a sequence, with the result having the same unit as
the current sequence.
If the value is not iterable, it raises a TypeError (same behavior as if
you tried to iterate over, for instance, an integer).
This function can take as arguments any arguments recognized by
`numpy.sum`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# This will be much faster for numpy arrays
mysum = self._value.sum()
mysum = self._value.sum(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.sum')
if len(self._value) == 0:
mysum = 0
else:
......@@ -475,31 +481,43 @@ class Quantity(object):
mysum += self._value[i]
return Quantity(mysum, self.unit)
def mean(self):
def mean(self, *args, **kwargs):
"""
Computes the mean of a sequence, with the result having the same unit as
the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.mean`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
mean = self._value.mean()
mean = self._value.mean(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.mean')
mean = self.sum() / len(self._value)
return Quantity(mean, self.unit)
def std(self):
def std(self, *args, **kwargs):
"""
Computes the square root of the variance of a sequence, with the result
having the same unit as the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.std`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
std = self._value.std()
std = self._value.std(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.std')
mean = self.mean()
for val in self._value:
res = mean - val
......@@ -508,30 +526,40 @@ class Quantity(object):
std = math.sqrt(var)
return Quantity(std, self.unit)
def max(self):
def max(self, *args, **kwargs):
"""
Computes the maximum value of the sequence, with the result having the
same unit as the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.max`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
mymax = self._value.max()
mymax = self._value.max(*args, **kwargs)
except AttributeError:
if args or kwargs:
raise TypeError('Unsupported arguments for Quantity.max')
mymax = max(self._value)
return Quantity(mymax, self.unit)
def min(self):
def min(self, *args, **kwargs):
"""
Computes the minimum value of the sequence, with the result having the
same unit as the current sequence.
If the value is not iterable, it raises a TypeError
This function can take as arguments any arguments recognized by
`numpy.min`. If arguments are passed to a non-numpy array, a TypeError
is raised
"""
try:
# Faster for numpy arrays
mymin = self._value.min()
mymin = self._value.min(*args, **kwargs)
except AttributeError:
mymin = min(self._value)
return Quantity(mymin, self.unit)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment