Commit ff586f5f authored by peastman's avatar peastman
Browse files

Merge pull request #643 from swails/numpy-unit

[WIP] Add functions to Quantity to compute the max, min, standard deviation and average of a Quantity, returning a Quantity.
parents 77e33f17 f6d90066
...@@ -625,9 +625,9 @@ Examples ...@@ -625,9 +625,9 @@ Examples
>>> 1.2*meters < 72*centimeters >>> 1.2*meters < 72*centimeters
False False
>>> meter != None >>> meter is not None
True True
>>> meter == None >>> meter is None
False False
Examples Examples
......
...@@ -60,7 +60,7 @@ def zeros(m, n=None): ...@@ -60,7 +60,7 @@ def zeros(m, n=None):
[0, 0, 0] [0, 0, 0]
[0, 0, 0]] [0, 0, 0]]
""" """
if n == None: if n is None:
n = m n = m
result = [] result = []
for row in range(0, m): for row in range(0, m):
......
...@@ -114,7 +114,7 @@ class Quantity(object): ...@@ -114,7 +114,7 @@ class Quantity(object):
- unit: (Unit) the physical unit, e.g. simtk.unit.meters. - unit: (Unit) the physical unit, e.g. simtk.unit.meters.
""" """
# When no unit is specified, bend over backwards to handle all one-argument possibilities # When no unit is specified, bend over backwards to handle all one-argument possibilities
if unit == None: # one argument version, copied from UList if unit is None: # one argument version, copied from UList
if is_unit(value): if is_unit(value):
# Unit argument creates an empty list with that unit attached # Unit argument creates an empty list with that unit attached
unit = value unit = value
...@@ -167,7 +167,7 @@ class Quantity(object): ...@@ -167,7 +167,7 @@ class Quantity(object):
value = value * unit._value value = value * unit._value
unit = unit.unit unit = unit.unit
# Use empty list for unspecified values # Use empty list for unspecified values
if value == None: if value is None:
value = [] value = []
self._value = value self._value = value
...@@ -306,7 +306,7 @@ class Quantity(object): ...@@ -306,7 +306,7 @@ class Quantity(object):
value_factor = 1.0 value_factor = 1.0
canonical_units = {} # dict of dimensionTuple: (Base/ScaledUnit, exponent) canonical_units = {} # dict of dimensionTuple: (Base/ScaledUnit, exponent)
# Bias result toward guide units # Bias result toward guide units
if guide_unit != None: if guide_unit is not None:
for u, exponent in guide_unit.iter_base_or_scaled_units(): for u, exponent in guide_unit.iter_base_or_scaled_units():
d = u.get_dimension_tuple() d = u.get_dimension_tuple()
if d not in canonical_units: if d not in canonical_units:
...@@ -455,6 +455,82 @@ class Quantity(object): ...@@ -455,6 +455,82 @@ class Quantity(object):
new_value *= math.sqrt(unit_factor) new_value *= math.sqrt(unit_factor)
return Quantity(value=new_value, unit=new_unit) return Quantity(value=new_value, unit=new_unit)
def sum(self):
"""
Computes the sum of a sequence, with the result having the same unit as
the current sequence.
If the value is not iterable, it raises a TypeError (same behavior as if
you tried to iterate over, for instance, an integer).
"""
try:
# This will be much faster for numpy arrays
mysum = self._value.sum()
except AttributeError:
mysum = sum(self._value)
return Quantity(mysum, self.unit)
def mean(self):
"""
Computes the mean of a sequence, with the result having the same unit as
the current sequence.
If the value is not iterable, it raises a TypeError
"""
try:
# Faster for numpy arrays
mean = self._value.mean()
except AttributeError:
mean = self.sum() / len(self._value)
return Quantity(mean, self.unit)
def std(self):
"""
Computes the square root of the variance of a sequence, with the result
having the same unit as the current sequence.
If the value is not iterable, it raises a TypeError
"""
try:
# Faster for numpy arrays
std = self._value.std()
except AttributeError:
mean = self.mean()
for val in self._value:
res = mean - val
var += res * res
var /= len(self._value)
std = math.sqrt(var)
return Quantity(std, self.unit)
def max(self):
"""
Computes the maximum value of the sequence, with the result having the
same unit as the current sequence.
If the value is not iterable, it raises a TypeError
"""
try:
# Faster for numpy arrays
mymax = self._value.max()
except AttributeError:
mymax = max(self._value)
return Quantity(mymax, self.unit)
def min(self):
"""
Computes the minimum value of the sequence, with the result having the
same unit as the current sequence.
If the value is not iterable, it raises a TypeError
"""
try:
# Faster for numpy arrays
mymin = self._value.min()
except AttributeError:
mymin = min(self._value)
return Quantity(mymin, self.unit)
def __abs__(self): def __abs__(self):
""" """
Return absolute value of a Quantity. Return absolute value of a Quantity.
......
...@@ -156,6 +156,10 @@ def sum(val): ...@@ -156,6 +156,10 @@ def sum(val):
>>> sum((2.0*meter, 30.0*centimeter)) >>> sum((2.0*meter, 30.0*centimeter))
Quantity(value=2.3, unit=meter) Quantity(value=2.3, unit=meter)
""" """
try:
return val.sum()
except AttributeError:
pass
if len(val) == 0: if len(val) == 0:
return 0 return 0
result = val[0] result = val[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment