Unverified Commit 39808f12 authored by Evan Pretti's avatar Evan Pretti Committed by GitHub
Browse files

Compare and track openmm.unit.BaseUnit objects by name (#4785)

* Compare and track openmm.unit.BaseUnit objects by name

* Retrieve _conversion_factors as a class attribute
parent c7216a94
......@@ -38,6 +38,8 @@ from __future__ import print_function, division, absolute_import
__author__ = "Christopher M. Bruns"
__version__ = "0.6"
import collections
class BaseUnit(object):
'''
Physical unit expressed in exactly one BaseDimension.
......@@ -47,23 +49,26 @@ class BaseUnit(object):
'''
__array_priority__ = 100
# Global table of conversion factors between base units
_conversion_factors = collections.defaultdict(lambda: collections.defaultdict(dict))
def __init__(self, base_dim, name, symbol):
"""Creates a new BaseUnit.
Parameters
- self: The newly created BaseUnit.
- base_dim: (BaseDimension) The dimension of the new unit, e.g. 'mass'
- name: (string) Name of the unit, e.g. "kilogram"
- name: (string) Name of the unit, e.g. "kilogram". This will be used
to distinguish between BaseUnit objects with the same dimension: two
BaseUnit objects with the same dimension that are given the same
name will be treated as equal to each other.
- symbol: (string) Symbol for the unit, e.g. 'kg'. This symbol will appear in
Quantity string descriptions.
"""
self.dimension = base_dim
self.name = name
self.symbol = symbol
self._conversion_factor_to = {}
self._conversion_factor_to[self] = 1.0
self._conversion_factor_to_by_name = {}
self._conversion_factor_to_by_name[self.name] = 1.0
BaseUnit._conversion_factors[self.dimension][self.name][self.name] = 1.0
def __lt__(self, other):
"""
......@@ -75,6 +80,14 @@ class BaseUnit(object):
# Second on conversion factor
return self.conversion_factor_to(other) < 1.0
def __eq__(self, other):
if not isinstance(other, BaseUnit):
return False
return self.dimension == other.dimension and self.name == other.name
def __hash__(self):
return hash((self.dimension, self.name))
def iter_base_dimensions(self):
"""
Returns a dictionary of BaseDimension:exponent pairs, describing the dimension of this unit.
......@@ -123,28 +136,25 @@ class BaseUnit(object):
if self.dimension != other.dimension:
raise TypeError('Cannot define conversion for BaseUnits with different dimensions.')
assert(factor != 0)
assert(not self is other)
assert(self != other)
conversion_factors = BaseUnit._conversion_factors[self.dimension]
conversion_factors_self = conversion_factors[self.name]
conversion_factors_other = conversion_factors[other.name]
# import all transitive conversions
self._conversion_factor_to[other] = factor
self._conversion_factor_to_by_name[other.name] = factor
for (unit, cfac) in other._conversion_factor_to.items():
if unit is self: continue
if unit in self._conversion_factor_to: continue
self._conversion_factor_to[unit] = factor * cfac
unit._conversion_factor_to[self] = pow(factor * cfac, -1)
self._conversion_factor_to_by_name[unit.name] = factor * cfac
unit._conversion_factor_to_by_name[self.name] = pow(factor * cfac, -1)
conversion_factors_self[other.name] = factor
for (unit_name, cfac) in conversion_factors_other.items():
if unit_name == self.name: continue
if unit_name in conversion_factors_self: continue
conversion_factors_self[unit_name] = factor * cfac
conversion_factors[unit_name][self.name] = pow(factor * cfac, -1)
# and for the other guy
invFac = pow(factor, -1.0)
other._conversion_factor_to[self] = invFac
other._conversion_factor_to_by_name[self.name] = invFac
for (unit, cfac) in self._conversion_factor_to.items():
if unit is other: continue
if unit in other._conversion_factor_to: continue
other._conversion_factor_to[unit] = invFac * cfac
unit._conversion_factor_to[other] = pow(invFac * cfac, -1)
other._conversion_factor_to_by_name[unit.name] = invFac * cfac
unit._conversion_factor_to_by_name[other.name] = pow(invFac * cfac, -1)
conversion_factors_other[self.name] = invFac
for (unit_name, cfac) in conversion_factors_self.items():
if unit_name == other.name: continue
if unit_name in conversion_factors_other: continue
conversion_factors_other[unit_name] = invFac * cfac
conversion_factors[unit_name][other.name] = pow(invFac * cfac, -1)
def conversion_factor_to(self, other):
"""Returns a conversion factor from this BaseUnit to another BaseUnit.
......@@ -159,9 +169,11 @@ class BaseUnit(object):
if self is other: return 1.0
if self.dimension != other.dimension:
raise TypeError('Cannot get conversion for BaseUnits with different dimensions.')
if not other.name in self._conversion_factor_to_by_name:
if self.name == other.name: return 1.0
conversion_factors_self = BaseUnit._conversion_factors[self.dimension][self.name]
if not other.name in conversion_factors_self:
raise LookupError('No conversion defined from BaseUnit "%s" to "%s".' % (self, other))
return self._conversion_factor_to_by_name[other.name]
return conversion_factors_self[other.name]
# run module directly for testing
if __name__=='__main__':
......
......@@ -64,7 +64,6 @@ class SiPrefix(object):
symbol = self.symbol + unit.symbol
name = self.prefix + unit.name
factor = self.factor
# TODO - check for existing BaseUnit with same name, symbol, and factor
new_base_unit = BaseUnit(unit.dimension, name, symbol)
new_base_unit.define_conversion_factor_to(unit, factor)
return new_base_unit
......@@ -73,7 +72,6 @@ class SiPrefix(object):
symbol = self.symbol + unit.symbol
name = self.prefix + unit.name
factor = self.factor * unit.factor
# TODO - check for existing BaseUnit with same name, symbol, and factor
return ScaledUnit(factor, unit.master, name, symbol)
elif isinstance(unit, Unit):
base_units = list(unit.iter_base_or_scaled_units())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment