Unverified Commit 39808f12 authored by Evan Pretti's avatar Evan Pretti Committed by GitHub
Browse files

Compare and track openmm.unit.BaseUnit objects by name (#4785)

* Compare and track openmm.unit.BaseUnit objects by name

* Retrieve _conversion_factors as a class attribute
parent c7216a94
...@@ -38,6 +38,8 @@ from __future__ import print_function, division, absolute_import ...@@ -38,6 +38,8 @@ from __future__ import print_function, division, absolute_import
__author__ = "Christopher M. Bruns" __author__ = "Christopher M. Bruns"
__version__ = "0.6" __version__ = "0.6"
import collections
class BaseUnit(object): class BaseUnit(object):
''' '''
Physical unit expressed in exactly one BaseDimension. Physical unit expressed in exactly one BaseDimension.
...@@ -47,23 +49,26 @@ class BaseUnit(object): ...@@ -47,23 +49,26 @@ class BaseUnit(object):
''' '''
__array_priority__ = 100 __array_priority__ = 100
# Global table of conversion factors between base units
_conversion_factors = collections.defaultdict(lambda: collections.defaultdict(dict))
def __init__(self, base_dim, name, symbol): def __init__(self, base_dim, name, symbol):
"""Creates a new BaseUnit. """Creates a new BaseUnit.
Parameters Parameters
- self: The newly created BaseUnit. - self: The newly created BaseUnit.
- base_dim: (BaseDimension) The dimension of the new unit, e.g. 'mass' - base_dim: (BaseDimension) The dimension of the new unit, e.g. 'mass'
- name: (string) Name of the unit, e.g. "kilogram" - name: (string) Name of the unit, e.g. "kilogram". This will be used
to distinguish between BaseUnit objects with the same dimension: two
BaseUnit objects with the same dimension that are given the same
name will be treated as equal to each other.
- symbol: (string) Symbol for the unit, e.g. 'kg'. This symbol will appear in - symbol: (string) Symbol for the unit, e.g. 'kg'. This symbol will appear in
Quantity string descriptions. Quantity string descriptions.
""" """
self.dimension = base_dim self.dimension = base_dim
self.name = name self.name = name
self.symbol = symbol self.symbol = symbol
self._conversion_factor_to = {} BaseUnit._conversion_factors[self.dimension][self.name][self.name] = 1.0
self._conversion_factor_to[self] = 1.0
self._conversion_factor_to_by_name = {}
self._conversion_factor_to_by_name[self.name] = 1.0
def __lt__(self, other): def __lt__(self, other):
""" """
...@@ -75,6 +80,14 @@ class BaseUnit(object): ...@@ -75,6 +80,14 @@ class BaseUnit(object):
# Second on conversion factor # Second on conversion factor
return self.conversion_factor_to(other) < 1.0 return self.conversion_factor_to(other) < 1.0
def __eq__(self, other):
if not isinstance(other, BaseUnit):
return False
return self.dimension == other.dimension and self.name == other.name
def __hash__(self):
return hash((self.dimension, self.name))
def iter_base_dimensions(self): def iter_base_dimensions(self):
""" """
Returns a dictionary of BaseDimension:exponent pairs, describing the dimension of this unit. Returns a dictionary of BaseDimension:exponent pairs, describing the dimension of this unit.
...@@ -123,28 +136,25 @@ class BaseUnit(object): ...@@ -123,28 +136,25 @@ class BaseUnit(object):
if self.dimension != other.dimension: if self.dimension != other.dimension:
raise TypeError('Cannot define conversion for BaseUnits with different dimensions.') raise TypeError('Cannot define conversion for BaseUnits with different dimensions.')
assert(factor != 0) assert(factor != 0)
assert(not self is other) assert(self != other)
conversion_factors = BaseUnit._conversion_factors[self.dimension]
conversion_factors_self = conversion_factors[self.name]
conversion_factors_other = conversion_factors[other.name]
# import all transitive conversions # import all transitive conversions
self._conversion_factor_to[other] = factor conversion_factors_self[other.name] = factor
self._conversion_factor_to_by_name[other.name] = factor for (unit_name, cfac) in conversion_factors_other.items():
for (unit, cfac) in other._conversion_factor_to.items(): if unit_name == self.name: continue
if unit is self: continue if unit_name in conversion_factors_self: continue
if unit in self._conversion_factor_to: continue conversion_factors_self[unit_name] = factor * cfac
self._conversion_factor_to[unit] = factor * cfac conversion_factors[unit_name][self.name] = pow(factor * cfac, -1)
unit._conversion_factor_to[self] = pow(factor * cfac, -1)
self._conversion_factor_to_by_name[unit.name] = factor * cfac
unit._conversion_factor_to_by_name[self.name] = pow(factor * cfac, -1)
# and for the other guy # and for the other guy
invFac = pow(factor, -1.0) invFac = pow(factor, -1.0)
other._conversion_factor_to[self] = invFac conversion_factors_other[self.name] = invFac
other._conversion_factor_to_by_name[self.name] = invFac for (unit_name, cfac) in conversion_factors_self.items():
for (unit, cfac) in self._conversion_factor_to.items(): if unit_name == other.name: continue
if unit is other: continue if unit_name in conversion_factors_other: continue
if unit in other._conversion_factor_to: continue conversion_factors_other[unit_name] = invFac * cfac
other._conversion_factor_to[unit] = invFac * cfac conversion_factors[unit_name][other.name] = pow(invFac * cfac, -1)
unit._conversion_factor_to[other] = pow(invFac * cfac, -1)
other._conversion_factor_to_by_name[unit.name] = invFac * cfac
unit._conversion_factor_to_by_name[other.name] = pow(invFac * cfac, -1)
def conversion_factor_to(self, other): def conversion_factor_to(self, other):
"""Returns a conversion factor from this BaseUnit to another BaseUnit. """Returns a conversion factor from this BaseUnit to another BaseUnit.
...@@ -159,9 +169,11 @@ class BaseUnit(object): ...@@ -159,9 +169,11 @@ class BaseUnit(object):
if self is other: return 1.0 if self is other: return 1.0
if self.dimension != other.dimension: if self.dimension != other.dimension:
raise TypeError('Cannot get conversion for BaseUnits with different dimensions.') raise TypeError('Cannot get conversion for BaseUnits with different dimensions.')
if not other.name in self._conversion_factor_to_by_name: if self.name == other.name: return 1.0
conversion_factors_self = BaseUnit._conversion_factors[self.dimension][self.name]
if not other.name in conversion_factors_self:
raise LookupError('No conversion defined from BaseUnit "%s" to "%s".' % (self, other)) raise LookupError('No conversion defined from BaseUnit "%s" to "%s".' % (self, other))
return self._conversion_factor_to_by_name[other.name] return conversion_factors_self[other.name]
# run module directly for testing # run module directly for testing
if __name__=='__main__': if __name__=='__main__':
......
...@@ -64,7 +64,6 @@ class SiPrefix(object): ...@@ -64,7 +64,6 @@ class SiPrefix(object):
symbol = self.symbol + unit.symbol symbol = self.symbol + unit.symbol
name = self.prefix + unit.name name = self.prefix + unit.name
factor = self.factor factor = self.factor
# TODO - check for existing BaseUnit with same name, symbol, and factor
new_base_unit = BaseUnit(unit.dimension, name, symbol) new_base_unit = BaseUnit(unit.dimension, name, symbol)
new_base_unit.define_conversion_factor_to(unit, factor) new_base_unit.define_conversion_factor_to(unit, factor)
return new_base_unit return new_base_unit
...@@ -73,7 +72,6 @@ class SiPrefix(object): ...@@ -73,7 +72,6 @@ class SiPrefix(object):
symbol = self.symbol + unit.symbol symbol = self.symbol + unit.symbol
name = self.prefix + unit.name name = self.prefix + unit.name
factor = self.factor * unit.factor factor = self.factor * unit.factor
# TODO - check for existing BaseUnit with same name, symbol, and factor
return ScaledUnit(factor, unit.master, name, symbol) return ScaledUnit(factor, unit.master, name, symbol)
elif isinstance(unit, Unit): elif isinstance(unit, Unit):
base_units = list(unit.iter_base_or_scaled_units()) base_units = list(unit.iter_base_or_scaled_units())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment