Commit 4eb893cf authored by Jason Swails's avatar Jason Swails Committed by peastman
Browse files

Improve vec3 (#2230)

* Improve vec3

* Added a ``__neg__`` operator overload so that -Vec3(1, 2, 3) returns
  the same thing as Vec3(-1, -2, -3)
* Derived Vec3 from namedtuple instead of tuple. This allows you to
  access the 3 elements of the vector by name. i.e., vec.x, vec.y, vec.z

* Make sure we use floating point division all the time.
parent 2a0dbe01
...@@ -28,13 +28,14 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -28,13 +28,14 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import absolute_import from __future__ import absolute_import, division
__author__ = "Peter Eastman" __author__ = "Peter Eastman"
__version__ = "1.0" __version__ = "1.0"
import simtk.unit as unit from .. import unit
from collections import namedtuple
class Vec3(tuple): class Vec3(namedtuple('Vec3', ['x', 'y', 'z'])):
"""Vec3 is a 3-element tuple that supports many math operations.""" """Vec3 is a 3-element tuple that supports many math operations."""
def __new__(cls, x, y, z): def __new__(cls, x, y, z):
...@@ -47,36 +48,39 @@ class Vec3(tuple): ...@@ -47,36 +48,39 @@ class Vec3(tuple):
def __add__(self, other): def __add__(self, other):
"""Add two Vec3s.""" """Add two Vec3s."""
return Vec3(self[0]+other[0], self[1]+other[1], self[2]+other[2]) return Vec3(self.x+other[0], self.y+other[1], self.z+other[2])
def __radd__(self, other): def __radd__(self, other):
"""Add two Vec3s.""" """Add two Vec3s."""
return Vec3(self[0]+other[0], self[1]+other[1], self[2]+other[2]) return Vec3(self.x+other[0], self.y+other[1], self.z+other[2])
def __sub__(self, other): def __sub__(self, other):
"""Add two Vec3s.""" """Add two Vec3s."""
return Vec3(self[0]-other[0], self[1]-other[1], self[2]-other[2]) return Vec3(self.x-other[0], self.y-other[1], self.z-other[2])
def __rsub__(self, other): def __rsub__(self, other):
"""Add two Vec3s.""" """Add two Vec3s."""
return Vec3(other[0]-self[0], other[1]-self[1], other[2]-self[2]) return Vec3(other[0]-self.x, other[1]-self.y, other[2]-self.z)
def __mul__(self, other): def __mul__(self, other):
"""Multiply a Vec3 by a constant.""" """Multiply a Vec3 by a constant."""
if unit.is_unit(other): if unit.is_unit(other):
return unit.Quantity(self, other) return unit.Quantity(self, other)
return Vec3(other*self[0], other*self[1], other*self[2]) return Vec3(other*self.x, other*self.y, other*self.z)
def __rmul__(self, other): def __rmul__(self, other):
"""Multiply a Vec3 by a constant.""" """Multiply a Vec3 by a constant."""
if unit.is_unit(other): if unit.is_unit(other):
return unit.Quantity(self, other) return unit.Quantity(self, other)
return Vec3(other*self[0], other*self[1], other*self[2]) return Vec3(other*self.x, other*self.y, other*self.z)
def __div__(self, other): def __div__(self, other):
"""Divide a Vec3 by a constant.""" """Divide a Vec3 by a constant."""
return Vec3(self[0]/other, self[1]/other, self[2]/other) return Vec3(self.x/other, self.y/other, self.z/other)
__truediv__ = __div__ __truediv__ = __div__
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return Vec3(self[0], self[1], self[2]) return Vec3(self.x, self.y, self.z)
def __neg__(self):
return Vec3(-self.x, -self.y, -self.z)
""" Tests the Vec3 object """
from unittest import TestCase
from simtk.openmm import Vec3
class TestVectors(TestCase):
""" Tests the Vec3 type """
def testVec3Attributes(self):
vec1 = Vec3(1, 2, 3)
self.assertEqual(vec1.x, 1)
self.assertEqual(vec1.y, 2)
self.assertEqual(vec1.z, 3)
def testNegation(self):
vec1 = Vec3(1, 2, 3)
vec1_neg = Vec3(-1, -2, -3)
self.assertEqual(-vec1, vec1_neg)
def testVec3Equality(self):
vec1 = Vec3(1, 2, 3)
vec2 = Vec3(1, 2, 3)
self.assertEqual(vec1, vec2)
def testVec3Addition(self):
vec1 = Vec3(1, 2, 3)
vec2 = Vec3(4, 5, 6)
vec2_tup = (4, 5, 6)
result = Vec3(5, 7, 9)
self.assertEqual(vec1 + vec2, result)
self.assertEqual(vec1 + vec2_tup, result)
def testVec3Subtraction(self):
vec1 = Vec3(1, 2, 3)
vec2 = Vec3(3, 2, 1)
vec2_tup = (3, 2, 1)
result = Vec3(-2, 0, 2)
self.assertEqual(vec1 - vec2, result)
self.assertEqual(vec1 - vec2_tup, result)
self.assertEqual(vec2_tup - vec1, -result)
def testVec3Multiplication(self):
vec1 = Vec3(1, 2, 3)
factor = 2
result = Vec3(2, 4, 6)
self.assertEqual(vec1 * factor, result)
self.assertEqual(factor * vec1, result)
def testVec3Division(self):
vec1 = Vec3(4, 5, 6)
factor = 2
result = Vec3(2, 2.5, 3)
self.assertEqual(vec1 / factor, result)
with self.assertRaises(TypeError):
2 / vec1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment